diff --git a/Project.toml b/Project.toml index 8990f70..8d046d6 100644 --- a/Project.toml +++ b/Project.toml @@ -2,9 +2,6 @@ name = "GraphDynamics" uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c" version = "0.4.7" -[workspace] -projects = ["test", "scrap"] - [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -13,6 +10,7 @@ OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" @@ -22,6 +20,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] MTKExt = ["Symbolics", "ModelingToolkit"] + [compat] Accessors = "0.1" ConstructionBase = "1.5" @@ -31,6 +30,7 @@ OhMyThreads = "0.6, 0.7, 0.8" OrderedCollections = "1.6.3" RecursiveArrayTools = "3" SciMLBase = "2" +SciMLStructures = "1.7.0" SparseArrays = "1" SymbolicIndexingInterface = "0.3" Symbolics = "6" @@ -41,3 +41,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test"] + +[workspace] +projects = ["test", "scrap"] diff --git a/src/GraphDynamics.jl b/src/GraphDynamics.jl index ff9739f..39b698f 100644 --- a/src/GraphDynamics.jl +++ b/src/GraphDynamics.jl @@ -129,6 +129,12 @@ using DiffEqBase: DiffEqBase, anyeltypedual +using SciMLStructures: + SciMLStructures, + Tunable, + Constants, + canonicalize, + replace #---------------------------------------------------------- # Random utils include("utils.jl") diff --git a/src/graph_system.jl b/src/graph_system.jl index 1178063..6a0cc73 100644 --- a/src/graph_system.jl +++ b/src/graph_system.jl @@ -111,7 +111,7 @@ function PartitionedGraphSystem(g::GraphSystem) in the graph, then we'd end up with - nodes_paritioned = [SysType1[n1, n2], SysType1[n3]] + nodes_paritioned = [SysType1[n1, n2], SysType2[n3]] ===================================================================================================# diff --git a/src/problems.jl b/src/problems.jl index dc7449b..1dd3197 100644 --- a/src/problems.jl +++ b/src/problems.jl @@ -213,3 +213,41 @@ function _problem(g::PartitionedGraphSystem, tspan; scheduler, allow_nonconcrete (; f, u, tspan, p, callback, tstops) end + +SciMLStructures.isscimlstructure(::GraphSystemParameters) = true +SciMLStructures.ismutablescimlstructure(::GraphSystemParameters) = false +SciMLStructures.hasportion(::Tunable, ::GraphSystemParameters) = true + +function SciMLStructures.canonicalize(::Tunable, p::GraphSystemParameters) + paramvals = map(p.params_partitioned) do paramclass + vals = map(paramclass) do paramobj + collect(values(NamedTuple(paramobj))) + end + reduce(vcat, vals) + end + buffer = reduce(vcat, paramvals) + + repack = let p = p + function (newbuffer) + replace(Tunable(), p, newbuffer) + end + end + buffer, repack, false +end + +function SciMLStructures.replace(::Tunable, p::GraphSystemParameters, newbuffer) + np = copy(p) + np_part = let i = 1 + map(np.params_partitioned) do paramclass + for j in 1:length(paramclass) + obj = paramclass[j] + syms = keys(NamedTuple(obj)) + paramclass[j] = set_param_prop(obj, (; (syms .=> view(newbuffer, i:i+length(syms)-1))...)) + i += length(syms) + end + end + @assert length(newbuffer) == i - 1 + end + @set np.params_partitioned = np_part + np +end diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index 2ede123..c949db0 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -231,3 +231,25 @@ function sensitivity_test() end end end + +using Mooncake, DifferentiationInterface, Enzyme, SciMLSensitivity +import SciMLStructures as SS + +function autodiff_test() + function sum_test(p; vjp = EnzymeVJP()) + prob = particle_osc_prob(;x1=1.0, x2=-1.0, m=3.0, mp1=1.0, kc_p1_p2=1.0, tspan = (0.0, 10.0), alg=Tsit5()) + buffer, repack, b = SS.canonicalize(SS.Tunable(), prob.p) + newp = repack(p) + prob = remake(prob; p = newp) + + sol = DiffEqBase.solve(prob, Tsit5(), saveat = 0.:0.5:10., sensealg = GaussAdjoint(; autojacvec = vjp)) + return sum(sol.u[end]) + end + + params = [1., 1., 2., 1., 3., 1., 0.] + @test_nowarn sum_test(params) + @test_nowarn sum_test(params, vjp = SciMLSensitivity.MooncakeVJP()) + + @test_nowarn value_and_gradient(sum_test, AutoEnzyme(), params) + @test_nowarn value_and_gradient(sum_test, AutoMooncake(), params) +end