@@ -100,9 +100,159 @@ end
100100 maybe_codegen_scimlproblem (expression, SteadyStateProblem{iip}, args; kwargs... )
101101end
102102
103+ @fallback_iip_specialize function SemilinearODEFunction {iip, specialize} (
104+ sys:: System ; u0 = nothing , p = nothing , t = nothing ,
105+ semiquadratic_form = nothing ,
106+ stiff_linear = true , stiff_quadratic = false , stiff_C = false ,
107+ eval_expression = false , eval_module = @__MODULE__ ,
108+ expression = Val{false }, sparse = false , check_compatibility = true ,
109+ jac = false , checkbounds = false , cse = true , initialization_data = nothing ,
110+ analytic = nothing , kwargs... ) where {iip, specialize}
111+ check_complete (sys, SemilinearODEFunction)
112+ check_compatibility && check_compatible_system (SemilinearODEFunction, sys)
113+
114+ if semiquadratic_form === nothing
115+ semiquadratic_form = calculate_semiquadratic_form (sys; sparse)
116+ sys = add_semiquadratic_parameters (sys, semiquadratic_form... )
117+ end
118+
119+ A, B, C = semiquadratic_form
120+ M = calculate_massmatrix (sys)
121+ _M = concrete_massmatrix (M; sparse, u0)
122+ dvs = unknowns (sys)
123+
124+ f1, f2 = generate_semiquadratic_functions (
125+ sys, A, B, C; stiff_linear, stiff_quadratic, stiff_C, expression, wrap_gfw = Val{true },
126+ eval_expression, eval_module, kwargs... )
127+
128+ if jac
129+ Cjac = (C === nothing || ! stiff_C) ? nothing : Symbolics. jacobian (C, dvs)
130+ _jac = generate_semiquadratic_jacobian (
131+ sys, A, B, C, Cjac; sparse, expression,
132+ wrap_gfw = Val{true }, eval_expression, eval_module, kwargs... )
133+ _W_sparsity = get_semiquadratic_W_sparsity (
134+ sys, A, B, C, Cjac; stiff_linear, stiff_quadratic, stiff_C, mm = M)
135+ W_prototype = calculate_W_prototype (_W_sparsity; u0, sparse)
136+ else
137+ _jac = nothing
138+ W_prototype = nothing
139+ end
140+
141+ observedfun = ObservedFunctionCache (
142+ sys; expression, steady_state = false , eval_expression, eval_module, checkbounds, cse)
143+
144+ args = (; f1)
145+ kwargs = (; jac = _jac, jac_prototype = W_prototype)
146+ f1 = maybe_codegen_scimlfn (expression, ODEFunction{iip, specialize}, args; kwargs... )
147+
148+ args = (; f1, f2)
149+ kwargs = (;
150+ sys = sys,
151+ jac = _jac,
152+ mass_matrix = _M,
153+ jac_prototype = W_prototype,
154+ observed = observedfun,
155+ analytic,
156+ initialization_data)
157+
158+ return maybe_codegen_scimlfn (
159+ expression, SplitFunction{iip, specialize}, args; kwargs... )
160+ end
161+
162+ @fallback_iip_specialize function SemilinearODEProblem {iip, spec} (
163+ sys:: System , op, tspan; check_compatibility = true , u0_eltype = nothing ,
164+ expression = Val{false }, callback = nothing , sparse = false ,
165+ stiff_linear = true , stiff_quadratic = false , stiff_C = false , jac = false , kwargs... ) where {
166+ iip, spec}
167+ check_complete (sys, SemilinearODEProblem)
168+ check_compatibility && check_compatible_system (SemilinearODEProblem, sys)
169+
170+ A, B, C = semiquadratic_form = calculate_semiquadratic_form (sys; sparse)
171+ eqs = equations (sys)
172+ dvs = unknowns (sys)
173+
174+ sys = add_semiquadratic_parameters (sys, A, B, C)
175+ if A != = nothing
176+ linear_matrix_param = unwrap (getproperty (sys, LINEAR_MATRIX_PARAM_NAME))
177+ else
178+ linear_matrix_param = nothing
179+ end
180+ if B != = nothing
181+ quadratic_forms = [unwrap (getproperty (sys, get_quadratic_form_name (i)))
182+ for i in 1 : length (eqs)]
183+ diffcache_par = unwrap (getproperty (sys, DIFFCACHE_PARAM_NAME))
184+ else
185+ quadratic_forms = diffcache_par = nothing
186+ end
187+
188+ op = to_varmap (op, dvs)
189+ floatT = calculate_float_type (op, typeof (op))
190+ _u0_eltype = something (u0_eltype, floatT)
191+
192+ guess = copy (guesses (sys))
193+ defs = copy (defaults (sys))
194+ if A != = nothing
195+ guess[linear_matrix_param] = fill (NaN , size (A))
196+ defs[linear_matrix_param] = A
197+ end
198+ if B != = nothing
199+ for (par, mat) in zip (quadratic_forms, B)
200+ guess[par] = fill (NaN , size (mat))
201+ defs[par] = mat
202+ end
203+ cachelen = jac ? length (dvs) * length (eqs) : length (dvs)
204+ defs[diffcache_par] = DiffCache (zeros (DiffEqBase. value (_u0_eltype), cachelen))
205+ end
206+ @set! sys. guesses = guess
207+ @set! sys. defaults = defs
208+
209+ f, u0, p = process_SciMLProblem (SemilinearODEFunction{iip, spec}, sys, op;
210+ t = tspan != = nothing ? tspan[1 ] : tspan, expression, check_compatibility,
211+ semiquadratic_form, sparse, u0_eltype, stiff_linear, stiff_quadratic, stiff_C, jac, kwargs... )
212+
213+ kwargs = process_kwargs (sys; expression, callback, kwargs... )
214+
215+ args = (; f, u0, tspan, p)
216+ maybe_codegen_scimlproblem (expression, SplitODEProblem{iip}, args; kwargs... )
217+ end
218+
219+ """
220+ $(TYPEDSIGNATURES)
221+
222+ Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices
223+ `A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref).
224+ """
225+ function add_semiquadratic_parameters (sys:: System , A, B, C)
226+ eqs = equations (sys)
227+ n = length (eqs)
228+ var_to_name = copy (get_var_to_name (sys))
229+ if B != = nothing
230+ for i in eachindex (B)
231+ B[i] === nothing && continue
232+ par = get_quadratic_form_param ((n, n), i)
233+ var_to_name[get_quadratic_form_name (i)] = par
234+ sys = with_additional_constant_parameter (sys, par)
235+ end
236+ par = get_diffcache_param (Float64)
237+ var_to_name[DIFFCACHE_PARAM_NAME] = par
238+ sys = with_additional_nonnumeric_parameter (sys, par)
239+ end
240+ if A != = nothing
241+ par = get_linear_matrix_param ((n, n))
242+ var_to_name[LINEAR_MATRIX_PARAM_NAME] = par
243+ sys = with_additional_constant_parameter (sys, par)
244+ end
245+ @set! sys. var_to_name = var_to_name
246+ if get_parent (sys) != = nothing
247+ @set! sys. parent = add_semiquadratic_parameters (get_parent (sys), A, B, C)
248+ end
249+ return sys
250+ end
251+
103252function check_compatible_system (
104253 T:: Union {Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
105- Type{DAEProblem}, Type{SteadyStateProblem}},
254+ Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
255+ Type{SemilinearODEProblem}},
106256 sys:: System )
107257 check_time_dependent (sys, T)
108258 check_not_dde (sys)
0 commit comments