From c46249e5e28a12adb8600e5fe2e8689218310cfb Mon Sep 17 00:00:00 2001 From: vyudu Date: Tue, 12 Aug 2025 15:27:38 -0400 Subject: [PATCH 1/4] feat: implement for GraphSystemParameters --- Project.toml | 9 ++++++--- src/GraphDynamics.jl | 6 ++++++ src/graph_system.jl | 2 +- src/problems.jl | 38 ++++++++++++++++++++++++++++++++++++ test/particle_osc_example.jl | 19 ++++++++++++++++++ test/runtests.jl | 4 ++++ 6 files changed, 74 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 8990f70..8d046d6 100644 --- a/Project.toml +++ b/Project.toml @@ -2,9 +2,6 @@ name = "GraphDynamics" uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c" version = "0.4.7" -[workspace] -projects = ["test", "scrap"] - [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -13,6 +10,7 @@ OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" @@ -22,6 +20,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] MTKExt = ["Symbolics", "ModelingToolkit"] + [compat] Accessors = "0.1" ConstructionBase = "1.5" @@ -31,6 +30,7 @@ OhMyThreads = "0.6, 0.7, 0.8" OrderedCollections = "1.6.3" RecursiveArrayTools = "3" SciMLBase = "2" +SciMLStructures = "1.7.0" SparseArrays = "1" SymbolicIndexingInterface = "0.3" Symbolics = "6" @@ -41,3 +41,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test"] + +[workspace] +projects = ["test", "scrap"] diff --git a/src/GraphDynamics.jl b/src/GraphDynamics.jl index ff9739f..39b698f 100644 --- a/src/GraphDynamics.jl +++ b/src/GraphDynamics.jl @@ -129,6 +129,12 @@ using DiffEqBase: DiffEqBase, anyeltypedual +using SciMLStructures: + SciMLStructures, + Tunable, + Constants, + canonicalize, + replace #---------------------------------------------------------- # Random utils include("utils.jl") diff --git a/src/graph_system.jl b/src/graph_system.jl index 1178063..6a0cc73 100644 --- a/src/graph_system.jl +++ b/src/graph_system.jl @@ -111,7 +111,7 @@ function PartitionedGraphSystem(g::GraphSystem) in the graph, then we'd end up with - nodes_paritioned = [SysType1[n1, n2], SysType1[n3]] + nodes_paritioned = [SysType1[n1, n2], SysType2[n3]] ===================================================================================================# diff --git a/src/problems.jl b/src/problems.jl index dc7449b..183fc7a 100644 --- a/src/problems.jl +++ b/src/problems.jl @@ -213,3 +213,41 @@ function _problem(g::PartitionedGraphSystem, tspan; scheduler, allow_nonconcrete (; f, u, tspan, p, callback, tstops) end + +SciMLStructures.isscimlstructure(::GraphSystemParameters) = true +SciMLStructures.ismutablescimlstructure(::GraphSystemParameters) = false +SciMLStructures.hasportion(::Tunable, ::GraphSystemParameters) = true + +function SciMLStructures.canonicalize(::Tunable, p::GraphSystemParameters) + paramvals = map(Iterators.flatten(p.params_partitioned)) do paramobj + values(NamedTuple(params)) + end + buffer = reduce(vcat(paramvals)) + + repack = let p = p + function repack(newbuffer) + replace(Tunable(), p, newbuffer) + end + end + buffer, repack, false +end + +function SciMLStructures.replace(::Tunable, p::GraphSystemParameters, newbuffer) + paramobjs = Iterators.flatten(p.params_partitioned) + N = sum([length(NamedTuple(obj)) for obj in paramobjs]) + @assert length(newbuffer) == N + + idx = 1 + new_params = map(paramobjs) do paramobj + syms = keys(NamedTuple(paramobj)) + newparams = typeof(paramobj)(; (syms .=> view(newbuffer, idx:idx+length(syms)-1))...) + idx += length(syms) + end + param_types = (unique ∘ imap)(typeof, new_params) + params_partitioned = Tuple(map(param_types) do T + filter(new_params) do p + p isa T + end + end) + @set p.params_partitioned = params_partitioned +end diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index 2ede123..7b608e2 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -231,3 +231,22 @@ function sensitivity_test() end end end + +using Mooncake, DifferentiationInterface, Enzyme, SciMLSensitivity + +function autodiff_test() + function f(y::Array{Float64}, u0::Array{Float64}; vjp = EnzymeVJP()) + tspan = (0.0, 3.0) + prob = particle_osc_prob(;x1=1.0, x2=-1.0, m=3.0, mp1=1.0, kc_p1_p2=1.0, tspan = (0.0, 10.0), alg=Tsit5()) + sol = DiffEqBase.solve(prob, Tsit5(), saveat = 0.:0.5:10., sensealg = GaussAdjoint(; autojacvec = vjp)) + y .= sol[1,:] + return nothing + end; + + d_u0 = zeros(6) + u0 = [1., 0., 0., -1., 0., 0.] + dy = zeros(21) + y = zeros(21) + + Enzyme.autodiff(Reverse, f, Duplicated(y, dy), Duplicated(u0, d_u0)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 3d677f5..55d0f02 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,3 +5,7 @@ using SafeTestsets solution_solve_test() sensitivity_test() end + +@safetestset "Autodiff" begin + include("autodiff.jl") +end From 547b23045f273fa7e26f4d20fc824684c9ce8221 Mon Sep 17 00:00:00 2001 From: vyudu Date: Tue, 12 Aug 2025 16:17:31 -0400 Subject: [PATCH 2/4] fix: fix return of canonicalize and replace --- src/problems.jl | 10 ++++++---- test/particle_osc_example.jl | 8 ++++---- test/runtests.jl | 4 ---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/problems.jl b/src/problems.jl index 183fc7a..9bf8124 100644 --- a/src/problems.jl +++ b/src/problems.jl @@ -220,12 +220,12 @@ SciMLStructures.hasportion(::Tunable, ::GraphSystemParameters) = true function SciMLStructures.canonicalize(::Tunable, p::GraphSystemParameters) paramvals = map(Iterators.flatten(p.params_partitioned)) do paramobj - values(NamedTuple(params)) + collect(values(NamedTuple(paramobj))) end - buffer = reduce(vcat(paramvals)) + buffer = reduce(vcat, paramvals) repack = let p = p - function repack(newbuffer) + function (newbuffer) replace(Tunable(), p, newbuffer) end end @@ -239,9 +239,11 @@ function SciMLStructures.replace(::Tunable, p::GraphSystemParameters, newbuffer) idx = 1 new_params = map(paramobjs) do paramobj + Main.xx[] = paramobj syms = keys(NamedTuple(paramobj)) - newparams = typeof(paramobj)(; (syms .=> view(newbuffer, idx:idx+length(syms)-1))...) + newparams = SubsystemParams{get_tag(paramobj)}(; (syms .=> view(newbuffer, idx:idx+length(syms)-1))...) idx += length(syms) + newparams end param_types = (unique ∘ imap)(typeof, new_params) params_partitioned = Tuple(map(param_types) do T diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index 7b608e2..72927ac 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -241,12 +241,12 @@ function autodiff_test() sol = DiffEqBase.solve(prob, Tsit5(), saveat = 0.:0.5:10., sensealg = GaussAdjoint(; autojacvec = vjp)) y .= sol[1,:] return nothing - end; + end d_u0 = zeros(6) - u0 = [1., 0., 0., -1., 0., 0.] + u0 = [1., 0., -1.0, 0., 0., 1.0] dy = zeros(21) y = zeros(21) - - Enzyme.autodiff(Reverse, f, Duplicated(y, dy), Duplicated(u0, d_u0)) + @test_nowarn f(y, u0) + @test_nowarn f(y, u0, vjp = SciMLSensitivity.MooncakeVJP()) end diff --git a/test/runtests.jl b/test/runtests.jl index 55d0f02..3d677f5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,3 @@ using SafeTestsets solution_solve_test() sensitivity_test() end - -@safetestset "Autodiff" begin - include("autodiff.jl") -end From f0a3720b937019fe0c7696699837eca60d49764e Mon Sep 17 00:00:00 2001 From: vyudu Date: Tue, 12 Aug 2025 16:53:23 -0400 Subject: [PATCH 3/4] create more sensible test case --- src/problems.jl | 3 +-- test/particle_osc_example.jl | 34 ++++++++++++++++++++++++---------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/problems.jl b/src/problems.jl index 9bf8124..d00904d 100644 --- a/src/problems.jl +++ b/src/problems.jl @@ -232,14 +232,13 @@ function SciMLStructures.canonicalize(::Tunable, p::GraphSystemParameters) buffer, repack, false end -function SciMLStructures.replace(::Tunable, p::GraphSystemParameters, newbuffer) +function SciMLStructures.replace(::Tunable, p::GraphSystemParameters, newbuffer)::GraphSystemParameters paramobjs = Iterators.flatten(p.params_partitioned) N = sum([length(NamedTuple(obj)) for obj in paramobjs]) @assert length(newbuffer) == N idx = 1 new_params = map(paramobjs) do paramobj - Main.xx[] = paramobj syms = keys(NamedTuple(paramobj)) newparams = SubsystemParams{get_tag(paramobj)}(; (syms .=> view(newbuffer, idx:idx+length(syms)-1))...) idx += length(syms) diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index 72927ac..13a0b4d 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -233,20 +233,34 @@ function sensitivity_test() end using Mooncake, DifferentiationInterface, Enzyme, SciMLSensitivity +import SciMLStructures as SS function autodiff_test() - function f(y::Array{Float64}, u0::Array{Float64}; vjp = EnzymeVJP()) - tspan = (0.0, 3.0) + function sum_test(p; vjp = EnzymeVJP())::Float64 prob = particle_osc_prob(;x1=1.0, x2=-1.0, m=3.0, mp1=1.0, kc_p1_p2=1.0, tspan = (0.0, 10.0), alg=Tsit5()) + newp = SS.replace(SS.Tunable(), prob.p, p) + prob = remake(prob; p = newp) sol = DiffEqBase.solve(prob, Tsit5(), saveat = 0.:0.5:10., sensealg = GaussAdjoint(; autojacvec = vjp)) - y .= sol[1,:] - return nothing + return sum(sol.u[end]) end - - d_u0 = zeros(6) + + prob = particle_osc_prob(;x1=1.0, x2=-1.0, m=3.0, mp1=1.0, kc_p1_p2=1.0, tspan = (0.0, 10.0), alg=Tsit5()) + params = SS.canonicalize(SS.Tunable(), prob.p)[1] u0 = [1., 0., -1.0, 0., 0., 1.0] - dy = zeros(21) - y = zeros(21) - @test_nowarn f(y, u0) - @test_nowarn f(y, u0, vjp = SciMLSensitivity.MooncakeVJP()) + + @test_nowarn sum_test(params) + @test_nowarn sum_test(params, vjp = SciMLSensitivity.MooncakeVJP()) + + # ForwardDiff + ForwardDiff.gradient(sum_test, params) + value_and_gradient(sum_test, AutoEnzyme(), params) + + # Enzyme + dp = zeros(7) + Enzyme.autodiff(Reverse, sum_test, params, dp) + + # Mooncake + backend = AutoMooncake(; config=nothing) + prep = prepare_gradient(sum_test, backend, params) + Mooncake.gradient(sum_test, prep, backend, params) end From 6bdc4cab40e1fc712ff3d392e3e1a2cfea1e7741 Mon Sep 17 00:00:00 2001 From: vyudu Date: Wed, 13 Aug 2025 15:46:18 -0400 Subject: [PATCH 4/4] fix: fix type instability in replace --- src/problems.jl | 41 ++++++++++++++++++------------------ test/particle_osc_example.jl | 25 ++++++---------------- 2 files changed, 27 insertions(+), 39 deletions(-) diff --git a/src/problems.jl b/src/problems.jl index d00904d..1dd3197 100644 --- a/src/problems.jl +++ b/src/problems.jl @@ -218,9 +218,12 @@ SciMLStructures.isscimlstructure(::GraphSystemParameters) = true SciMLStructures.ismutablescimlstructure(::GraphSystemParameters) = false SciMLStructures.hasportion(::Tunable, ::GraphSystemParameters) = true -function SciMLStructures.canonicalize(::Tunable, p::GraphSystemParameters) - paramvals = map(Iterators.flatten(p.params_partitioned)) do paramobj - collect(values(NamedTuple(paramobj))) +function SciMLStructures.canonicalize(::Tunable, p::GraphSystemParameters) + paramvals = map(p.params_partitioned) do paramclass + vals = map(paramclass) do paramobj + collect(values(NamedTuple(paramobj))) + end + reduce(vcat, vals) end buffer = reduce(vcat, paramvals) @@ -232,23 +235,19 @@ function SciMLStructures.canonicalize(::Tunable, p::GraphSystemParameters) buffer, repack, false end -function SciMLStructures.replace(::Tunable, p::GraphSystemParameters, newbuffer)::GraphSystemParameters - paramobjs = Iterators.flatten(p.params_partitioned) - N = sum([length(NamedTuple(obj)) for obj in paramobjs]) - @assert length(newbuffer) == N - - idx = 1 - new_params = map(paramobjs) do paramobj - syms = keys(NamedTuple(paramobj)) - newparams = SubsystemParams{get_tag(paramobj)}(; (syms .=> view(newbuffer, idx:idx+length(syms)-1))...) - idx += length(syms) - newparams - end - param_types = (unique ∘ imap)(typeof, new_params) - params_partitioned = Tuple(map(param_types) do T - filter(new_params) do p - p isa T +function SciMLStructures.replace(::Tunable, p::GraphSystemParameters, newbuffer) + np = copy(p) + np_part = let i = 1 + map(np.params_partitioned) do paramclass + for j in 1:length(paramclass) + obj = paramclass[j] + syms = keys(NamedTuple(obj)) + paramclass[j] = set_param_prop(obj, (; (syms .=> view(newbuffer, i:i+length(syms)-1))...)) + i += length(syms) + end end - end) - @set p.params_partitioned = params_partitioned + @assert length(newbuffer) == i - 1 + end + @set np.params_partitioned = np_part + np end diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index 13a0b4d..c949db0 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -236,31 +236,20 @@ using Mooncake, DifferentiationInterface, Enzyme, SciMLSensitivity import SciMLStructures as SS function autodiff_test() - function sum_test(p; vjp = EnzymeVJP())::Float64 + function sum_test(p; vjp = EnzymeVJP()) prob = particle_osc_prob(;x1=1.0, x2=-1.0, m=3.0, mp1=1.0, kc_p1_p2=1.0, tspan = (0.0, 10.0), alg=Tsit5()) - newp = SS.replace(SS.Tunable(), prob.p, p) + buffer, repack, b = SS.canonicalize(SS.Tunable(), prob.p) + newp = repack(p) prob = remake(prob; p = newp) + sol = DiffEqBase.solve(prob, Tsit5(), saveat = 0.:0.5:10., sensealg = GaussAdjoint(; autojacvec = vjp)) return sum(sol.u[end]) end - prob = particle_osc_prob(;x1=1.0, x2=-1.0, m=3.0, mp1=1.0, kc_p1_p2=1.0, tspan = (0.0, 10.0), alg=Tsit5()) - params = SS.canonicalize(SS.Tunable(), prob.p)[1] - u0 = [1., 0., -1.0, 0., 0., 1.0] - + params = [1., 1., 2., 1., 3., 1., 0.] @test_nowarn sum_test(params) @test_nowarn sum_test(params, vjp = SciMLSensitivity.MooncakeVJP()) - # ForwardDiff - ForwardDiff.gradient(sum_test, params) - value_and_gradient(sum_test, AutoEnzyme(), params) - - # Enzyme - dp = zeros(7) - Enzyme.autodiff(Reverse, sum_test, params, dp) - - # Mooncake - backend = AutoMooncake(; config=nothing) - prep = prepare_gradient(sum_test, backend, params) - Mooncake.gradient(sum_test, prep, backend, params) + @test_nowarn value_and_gradient(sum_test, AutoEnzyme(), params) + @test_nowarn value_and_gradient(sum_test, AutoMooncake(), params) end