@@ -14,7 +14,7 @@ Before getting to the explanation, here's some code to start with. We will follo
 
 ``` @example hamiltonian_cp
 using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
-    ComponentArrays, Optimization, OptimizationOptimisers, IterTools
+    ComponentArrays, Optimization, OptimizationOptimisers, MLUtils
 
 t = range(0.0f0, 1.0f0; length = 1024)
 π_32 = Float32(π)
@@ -23,37 +23,33 @@ p_t = reshape(cos.(2π_32 * t), 1, :)
 dqdt = 2π_32 .* p_t
 dpdt = -2π_32 .* q_t
 
-data = vcat(q_t, p_t)
-target = vcat(dqdt, dpdt)
+data = cat(q_t, p_t; dims = 1)
+target = cat(dqdt, dpdt; dims = 1)
 B = 256
-NEPOCHS = 100
-dataloader = ncycle(
-    ((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
-         selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
-     for i in 1:(size(data, 2) ÷ B)),
-    NEPOCHS)
-
-hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
+NEPOCHS = 500
+dataloader = DataLoader((data, target); batchsize = B)
+
+hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote())
 ps, st = Lux.setup(Xoshiro(0), hnn)
 ps_c = ps |> ComponentArray
 
 opt = OptimizationOptimisers.Adam(0.01f0)
 
-function loss_function(ps, data, target)
+function loss_function(ps, databatch)
+    data, target = databatch
     pred, st_ = hnn(data, ps, st)
-    return mean(abs2, pred .- target), pred
+    return mean(abs2, pred .- target)
 end
 
-function callback(ps, loss, pred)
+function callback(state, loss)
     println("[Hamiltonian NN] Loss: ", loss)
     return false
 end
 
-opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target),
-    Optimization.AutoForwardDiff())
-opt_prob = OptimizationProblem(opt_func, ps_c)
+opt_func = OptimizationFunction(loss_function, Optimization.AutoForwardDiff())
+opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)
 
-res = Optimization.solve(opt_prob, opt, dataloader; callback)
+res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS)
 
 ps_trained = res.u
 
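
The hunk above replaces the hand-rolled `IterTools.ncycle`/`selectdim` batching with `MLUtils.DataLoader` and moves the data iterator into the `OptimizationProblem`, letting `Optimization.solve` drive training via the `epochs` keyword. Here is a minimal sketch of that pattern on a toy problem; the linear model, data, and names (`toy_loss`, `ps_toy`) are illustrative assumptions, not part of the tutorial:

```julia
# A minimal sketch of the DataLoader-in-the-problem pattern, assuming a
# recent Optimization.jl / MLUtils; the toy model and data are made up.
using Optimization, OptimizationOptimisers, MLUtils, ComponentArrays

x = randn(Float32, 2, 1024)                   # toy inputs
y = 2.0f0 .* x                                # toy targets
loader = DataLoader((x, y); batchsize = 256)  # batches along the last dim

ps_toy = ComponentArray(W = randn(Float32, 2, 2))

# With the data stored in the problem, the loss receives (params, batch).
function toy_loss(ps, batch)
    xb, yb = batch
    return sum(abs2, ps.W * xb .- yb) / size(xb, 2)
end

optf = OptimizationFunction(toy_loss, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, ps_toy, loader)  # data lives in the problem
res = Optimization.solve(prob, OptimizationOptimisers.Adam(0.01f0); epochs = 10)
```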
@@ -75,7 +71,7 @@ The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we
 
 ``` @example hamiltonian
 using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
-    ComponentArrays, Optimization, OptimizationOptimisers, IterTools
+    ComponentArrays, Optimization, OptimizationOptimisers, MLUtils
 
 t = range(0.0f0, 1.0f0; length = 1024)
 π_32 = Float32(π)
@@ -87,40 +83,37 @@ dpdt = -2π_32 .* q_t
 data = cat(q_t, p_t; dims = 1)
 target = cat(dqdt, dpdt; dims = 1)
 B = 256
-NEPOCHS = 100
-dataloader = ncycle(
-    ((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
-         selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
-     for i in 1:(size(data, 2) ÷ B)),
-    NEPOCHS)
+NEPOCHS = 500
+dataloader = DataLoader((data, target); batchsize = B)
 ```
 
 ### Training the HamiltonianNN
 
 We parameterize the Hamiltonian with a small Multilayer Perceptron. HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN Layer for Optimization.
 
 ``` @example hamiltonian
-hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
+hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote())
 ps, st = Lux.setup(Xoshiro(0), hnn)
 ps_c = ps |> ComponentArray
+hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st)
 
-opt = OptimizationOptimisers.Adam(0.01f0)
+opt = OptimizationOptimisers.Adam(0.005f0)
 
-function loss_function(ps, data, target)
-    pred, st_ = hnn(data, ps, st)
-    return mean(abs2, pred .- target), pred
+function loss_function(ps, databatch)
+    (data, target) = databatch
+    pred = hnn_stateful(data, ps)
+    return mean(abs2, pred .- target)
 end
 
-function callback(ps, loss, pred)
+function callback(state, loss)
     println("[Hamiltonian NN] Loss: ", loss)
     return false
 end
 
-opt_func = OptimizationFunction(
-    (ps, _, data, target) -> loss_function(ps, data, target), Optimization.AutoZygote())
-opt_prob = OptimizationProblem(opt_func, ps_c)
+opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)
 
-res = solve(opt_prob, opt, dataloader; callback)
+res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS)
 
 ps_trained = res.u
 ```
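
For context on what the loss above is fitting: `Layers.HamiltonianNN` learns a scalar Hamiltonian ``H(q, p)`` and returns the symplectic gradient ``(\dot q, \dot p) = (\partial H / \partial p, -\partial H / \partial q)``, which is regressed against the analytic derivatives. A hand-rolled sketch of that computation with a plain Lux network (the network size and names here are illustrative, not the library's internals):

```julia
# Sketch of the Hamiltonian vector field an HNN layer computes, assuming a
# scalar network H; this mirrors the idea, not the library implementation.
using Lux, Zygote, Random

H = Chain(Dense(2 => 64, tanh), Dense(64 => 1))  # scalar Hamiltonian H(q, p)
ps, st = Lux.setup(Xoshiro(0), H)

function hamiltonian_vector_field(x, ps, st)
    # Gradient of H with respect to the inputs (q, p), batched over columns.
    gs = only(Zygote.gradient(x -> sum(first(H(x, ps, st))), x))
    dHdq, dHdp = gs[1:1, :], gs[2:2, :]
    return vcat(dHdp, -dHdq)  # (q̇, ṗ) = (∂H/∂p, -∂H/∂q)
end

x = randn(Float32, 2, 16)
hamiltonian_vector_field(x, ps, st)  # 2×16 predictions of (q̇, ṗ)
```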