Skip to content

Commit 1534921

Browse files
Merge pull request #503 from ChrisRackauckas-Claude/pr-464-rebased
feat: add support for OverrideInit and CheckInit (rebased)
2 parents aa00ca0 + 5531c60 commit 1534921

File tree

15 files changed

+308
-166
lines changed

15 files changed

+308
-166
lines changed

Project.toml

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,57 +4,69 @@ authors = ["Chris Rackauckas <[email protected]>"]
44
version = "4.28.0"
55

66
[deps]
7+
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
8+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
79
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
810
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
911
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1012
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1113
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1215
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1316
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1417
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1518
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1619
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1720
Sundials_jll = "fb77eaff-e24c-56d4-86b1-d163f2edb164"
21+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1822

1923
[compat]
24+
Accessors = "0.1.38"
25+
ADTypes = "1"
2026
AlgebraicMultigrid = "1"
2127
Aqua = "0.8"
28+
ArrayInterface = "7.17.1"
2229
CEnum = "0.5"
2330
DAEProblemLibrary = "0.1"
2431
DataStructures = "0.18, 0.19"
2532
DiffEqBase = "6.154"
2633
DiffEqCallbacks = "4"
34+
DifferentiationInterface = "0.6, 0.7"
2735
ExplicitImports = "1"
2836
ForwardDiff = "0.10"
2937
IncompleteLU = "0.2"
3038
Libdl = "1"
3139
LinearAlgebra = "1"
40+
LinearSolve = "3.40.0"
3241
Logging = "1"
3342
ModelingToolkit = "10"
3443
ODEProblemLibrary = "1"
3544
PrecompileTools = "1"
3645
Reexport = "1.0"
37-
SciMLBase = "2.9"
46+
SafeTestsets = "0.1"
47+
SciMLBase = "2.119.0"
3848
SparseArrays = "1"
39-
SparseConnectivityTracer = "0.6"
40-
SparseDiffTools = "2"
49+
SparseConnectivityTracer = "1"
4150
Sundials_jll = "7.4.1"
51+
SymbolicIndexingInterface = "0.3.35"
4252
Test = "1"
4353
julia = "1.10"
4454

4555
[extras]
56+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
4657
AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c"
4758
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4859
DAEProblemLibrary = "dfb8ca35-80a1-48ba-a605-84916a45b4f8"
4960
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
61+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
5062
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
5163
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5264
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
5365
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
5466
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
67+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
5568
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
56-
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
5769
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5870

5971
[targets]
60-
test = ["Test", "AlgebraicMultigrid", "Aqua", "DiffEqCallbacks", "ExplicitImports", "ODEProblemLibrary", "DAEProblemLibrary", "ForwardDiff", "SparseDiffTools", "SparseConnectivityTracer", "IncompleteLU", "ModelingToolkit"]
72+
test = ["Test", "ADTypes", "AlgebraicMultigrid", "Aqua", "DiffEqCallbacks", "ExplicitImports", "ODEProblemLibrary", "DAEProblemLibrary", "ForwardDiff", "DifferentiationInterface", "SparseConnectivityTracer", "IncompleteLU", "ModelingToolkit", "SafeTestsets"]

analyze_imports.jl

Lines changed: 0 additions & 31 deletions
This file was deleted.

src/Sundials.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@ using DiffEqBase: DiffEqBase, NonlinearFunction, ODEFunction, add_saveat!,
1313
update_coefficients!, warn_compat
1414
using SciMLBase: AbstractSciMLOperator, DAEProblem, ODEProblem, ReturnCode,
1515
SciMLBase, SplitODEProblem, VectorContinuousCallback
16+
import Accessors: @reset
17+
import ArrayInterface
18+
import SymbolicIndexingInterface as SII
19+
import SymbolicIndexingInterface: ParameterIndexingProxy
1620
using DataStructures: DataStructures
1721
using Logging: Logging
1822
using SparseArrays: SparseArrays
1923
using LinearAlgebra: LinearAlgebra
2024

25+
import LinearSolve # Required for initialization
2126
using Libdl: Libdl
2227
using CEnum: CEnum, @cenum
2328

@@ -91,6 +96,7 @@ include("common_interface/verbosity.jl")
9196
include("common_interface/algorithms.jl")
9297
include("common_interface/integrator_types.jl")
9398
include("common_interface/integrator_utils.jl")
99+
include("common_interface/initialize_dae.jl")
94100
include("common_interface/solve.jl")
95101

96102
import PrecompileTools
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
struct SundialsDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end
2+
3+
function DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator, initializealg = integrator.initializealg)
4+
_initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob)))
5+
end
6+
7+
struct IDADefaultInit <: DiffEqBase.DAEInitializationAlgorithm
8+
end
9+
10+
function _initialize_dae!(integrator::IDAIntegrator, prob,
11+
initializealg::IDADefaultInit, isinplace)
12+
if integrator.u_modified
13+
IDAReinit!(integrator)
14+
end
15+
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
16+
tstart, tend = integrator.sol.prob.tspan
17+
if any(abs.(integrator.tmp) .>= integrator.opts.reltol)
18+
if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all
19+
error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.")
20+
end
21+
if integrator.alg.init_all
22+
init_type = IDA_Y_INIT
23+
else
24+
init_type = IDA_YA_YDP_INIT
25+
# Use preallocated NVector for differential_vars
26+
if integrator.diff_vars_nvec !== nothing
27+
integrator.flag = IDASetId(integrator.mem, integrator.diff_vars_nvec)
28+
else
29+
error("differential_vars NVector not preallocated but needed for IDASetId")
30+
end
31+
end
32+
dt = integrator.dt == tstart ? tend : integrator.dt
33+
integrator.flag = IDACalcIC(integrator.mem, init_type, dt)
34+
35+
# Reflect consistent initial conditions back into the integrator's
36+
# shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}).
37+
IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec)
38+
end
39+
if integrator.t == tstart && integrator.flag < 0
40+
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
41+
ReturnCode.InitialFailure)
42+
end
43+
end
44+
45+
function _initialize_dae!(integrator, prob, ::SundialsDefaultInit, isinplace)
46+
if SciMLBase.has_initializeprob(prob.f)
47+
_initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace)
48+
elseif integrator isa IDAIntegrator
49+
_initialize_dae!(integrator, prob, IDADefaultInit(), isinplace)
50+
end
51+
end
52+
53+
function _initialize_dae!(integrator, prob, initalg::SciMLBase.NoInit, isinplace) end
54+
55+
function _initialize_dae!(integrator, prob, initalg::SciMLBase.OverrideInit, isinplace::Union{Val{true}, Val{false}})
56+
nlsolve_alg = KINSOL()
57+
u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)
58+
59+
if isinplace === Val{true}()
60+
integrator.u .= u0
61+
if length(integrator.sol.u) == 1
62+
integrator.sol.u[1] .= u0
63+
end
64+
else
65+
integrator.u = u0
66+
if length(integrator.sol.u) == 1
67+
integrator.sol.u[1] = u0
68+
end
69+
end
70+
integrator.p = p
71+
sol = integrator.sol
72+
@reset sol.prob.p = integrator.p
73+
integrator.sol = sol
74+
75+
# For IDA, we need to reinitialize the solver after changing u, du, or p
76+
if integrator isa IDAIntegrator && success
77+
integrator.u_modified = true
78+
end
79+
80+
if !success
81+
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, ReturnCode.InitialFailure)
82+
end
83+
end
84+
85+
function _initialize_dae!(integrator, prob, initalg::SciMLBase.CheckInit, isinplace::Union{Val{true}, Val{false}})
86+
SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, isinplace; abstol = integrator.opts.abstol)
87+
end

src/common_interface/integrator_types.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ mutable struct CVODEIntegrator{N,
4040
oType,
4141
LStype,
4242
Atype,
43-
CallbackCacheType} <: AbstractSundialsIntegrator{algType}
43+
CallbackCacheType,
44+
IA} <: AbstractSundialsIntegrator{algType}
4445
u::Array{Float64, N}
4546
u_nvec::NVector
4647
p::pType
@@ -66,6 +67,7 @@ mutable struct CVODEIntegrator{N,
6667
vector_event_last_time::Int
6768
callback_cache::CallbackCacheType
6869
last_event_error::Float64
70+
initializealg::IA
6971
ctx_handle::ContextHandle
7072
end
7173

@@ -102,7 +104,8 @@ mutable struct ARKODEIntegrator{N,
102104
MLStype,
103105
Mtype,
104106
CallbackCacheType,
105-
MemType} <: AbstractSundialsIntegrator{ARKODE}
107+
MemType,
108+
IA} <: AbstractSundialsIntegrator{ARKODE}
106109
u::Array{Float64, N}
107110
u_nvec::NVector
108111
p::pType
@@ -130,15 +133,16 @@ mutable struct ARKODEIntegrator{N,
130133
vector_event_last_time::Int
131134
callback_cache::CallbackCacheType
132135
last_event_error::Float64
136+
initializealg::IA
133137
ctx_handle::ContextHandle
134138
end
135139

136140
function (integrator::ARKODEIntegrator{
137141
N, pType, solType, algType, fType, UFType, JType, oType,
138-
LStype, Atype, MLStype, Mtype, CallbackCacheType, ARKStepMem})(t::Number,
142+
LStype, Atype, MLStype, Mtype, CallbackCacheType, ARKStepMem, IA})(t::Number,
139143
deriv::Type{Val{T}} = Val{0};
140144
idxs = nothing) where {N, pType, solType, algType, fType, UFType, JType, oType,
141-
LStype, Atype, MLStype, Mtype, CallbackCacheType, T}
145+
LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T}
142146
out = similar(integrator.u)
143147
out_nvec = NVector(vec(out), integrator.ctx_handle.ctx)
144148
integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), out_nvec)
@@ -148,10 +152,10 @@ end
148152

149153
function (integrator::ARKODEIntegrator{
150154
N, pType, solType, algType, fType, UFType, JType, oType,
151-
LStype, Atype, MLStype, Mtype, CallbackCacheType, ERKStepMem})(t::Number,
155+
LStype, Atype, MLStype, Mtype, CallbackCacheType, ERKStepMem, IA})(t::Number,
152156
deriv::Type{Val{T}} = Val{0};
153157
idxs = nothing) where {N, pType, solType, algType, fType, UFType, JType, oType,
154-
LStype, Atype, MLStype, Mtype, CallbackCacheType, T}
158+
LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T}
155159
out = similar(integrator.u)
156160
out_nvec = NVector(vec(out), integrator.ctx_handle.ctx)
157161
integrator.flag = @checkflag ERKStepGetDky(integrator.mem, t, Cint(T), out_nvec)
@@ -161,11 +165,11 @@ end
161165

162166
function (integrator::ARKODEIntegrator{
163167
N, pType, solType, algType, fType, UFType, JType, oType,
164-
LStype, Atype, MLStype, Mtype, CallbackCacheType, ARKStepMem})(out,
168+
LStype, Atype, MLStype, Mtype, CallbackCacheType, ARKStepMem, IA})(out,
165169
t::Number,
166170
deriv::Type{Val{T}} = Val{0};
167171
idxs = nothing) where {N, pType, solType, algType, fType, UFType, JType, oType,
168-
LStype, Atype, MLStype, Mtype, CallbackCacheType, T}
172+
LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T}
169173
out_nvec = NVector(vec(out), integrator.ctx_handle.ctx)
170174
integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), out_nvec)
171175
copyto!(out, out_nvec.v)
@@ -174,11 +178,11 @@ end
174178

175179
function (integrator::ARKODEIntegrator{
176180
N, pType, solType, algType, fType, UFType, JType, oType,
177-
LStype, Atype, MLStype, Mtype, CallbackCacheType, ERKStepMem})(out,
181+
LStype, Atype, MLStype, Mtype, CallbackCacheType, ERKStepMem, IA})(out,
178182
t::Number,
179183
deriv::Type{Val{T}} = Val{0};
180184
idxs = nothing) where {N, pType, solType, algType, fType, UFType, JType, oType,
181-
LStype, Atype, MLStype, Mtype, CallbackCacheType, T}
185+
LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T}
182186
out_nvec = NVector(vec(out), integrator.ctx_handle.ctx)
183187
integrator.flag = @checkflag ERKStepGetDky(integrator.mem, t, Cint(T), out_nvec)
184188
copyto!(out, out_nvec.v)

src/common_interface/integrator_utils.jl

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ end
157157

158158
function handle_callback_modifiers!(integrator::IDAIntegrator)
159159
# Implicitly does IDAReinit!
160-
DiffEqBase.initialize_dae!(integrator, IDADefaultInit())
160+
DiffEqBase.initialize_dae!(integrator)
161161
end
162162

163163
function DiffEqBase.add_tstop!(integrator::AbstractSundialsIntegrator, t)
@@ -211,6 +211,8 @@ end
211211
@inline function Base.getproperty(integrator::AbstractSundialsIntegrator, sym::Symbol)
212212
if sym == :dt
213213
return integrator.t - integrator.tprev
214+
elseif sym == :ps
215+
return ParameterIndexingProxy(integrator)
214216
else
215217
return getfield(integrator, sym)
216218
end
@@ -228,46 +230,6 @@ end
228230
# Required for callbacks
229231
DiffEqBase.set_proposed_dt!(i::AbstractSundialsIntegrator, dt) = nothing
230232

231-
DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator) = nothing
232-
233-
struct IDADefaultInit <: DiffEqBase.DAEInitializationAlgorithm
234-
end
235-
236-
function DiffEqBase.initialize_dae!(integrator::IDAIntegrator,
237-
initializealg::IDADefaultInit)
238-
if integrator.u_modified
239-
IDAReinit!(integrator)
240-
end
241-
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
242-
tstart, tend = integrator.sol.prob.tspan
243-
if any(abs.(integrator.tmp) .>= integrator.opts.reltol)
244-
if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all
245-
error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.")
246-
end
247-
if integrator.alg.init_all
248-
init_type = IDA_Y_INIT
249-
else
250-
init_type = IDA_YA_YDP_INIT
251-
# Use preallocated NVector for differential_vars
252-
if integrator.diff_vars_nvec !== nothing
253-
integrator.flag = IDASetId(integrator.mem, integrator.diff_vars_nvec)
254-
else
255-
error("differential_vars NVector not preallocated but needed for IDASetId")
256-
end
257-
end
258-
dt = integrator.dt == tstart ? tend : integrator.dt
259-
integrator.flag = IDACalcIC(integrator.mem, init_type, dt)
260-
261-
# Reflect consistent initial conditions back into the integrator's
262-
# shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}).
263-
IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec)
264-
end
265-
if integrator.t == tstart && integrator.flag < 0
266-
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
267-
ReturnCode.InitialFailure)
268-
end
269-
end
270-
271233
DiffEqBase.has_reinit(integrator::AbstractSundialsIntegrator) = true
272234
function DiffEqBase.reinit!(integrator::AbstractSundialsIntegrator,
273235
u0 = integrator.sol.prob.u0;
@@ -343,3 +305,6 @@ DiffEqBase.get_tstops_array(integ::AbstractSundialsIntegrator) = get_tstops(inte
343305
function DiffEqBase.get_tstops_max(integ::AbstractSundialsIntegrator)
344306
maximum(get_tstops_array(integ))
345307
end
308+
309+
# SII
310+
SII.symbolic_container(integ::AbstractSundialsIntegrator) = integ.sol

0 commit comments

Comments
 (0)