# Implicit differentiation

## Background

Differentiating implicit functions efficiently using the implicit function theorem has many applications, including:
- Nonlinear partial differential equation constrained optimization
- Differentiable optimization layers in deep learning (aka deep declarative networks)
- Differentiable fixed point iteration algorithms for optimal transport (e.g. the Sinkhorn methods)
- Gradient-based bi-level and robust optimization (aka anti-optimization)
- Multi-parametric programming (aka optimization sensitivity analysis)

For more on implicit differentiation, refer to the last part of the [_Understanding automatic differentiation (in Julia)_](https://www.youtube.com/watch?v=UqymrMG-Qi4) video on YouTube and to the [_Efficient and modular implicit differentiation_](https://arxiv.org/abs/2105.15183) manuscript for an introduction to the methods implemented here.

## Relationship to [`ImplicitDifferentiation.jl`](https://github.com/gdalle/ImplicitDifferentiation.jl)

[`ImplicitDifferentiation.jl`](https://github.com/gdalle/ImplicitDifferentiation.jl) is an attempt to simplify the implementation in `Nonconvex`, making it more lightweight and better documented. For instance, the [documentation of `ImplicitDifferentiation`](https://gdalle.github.io/ImplicitDifferentiation.jl/) presents a number of examples of implicit functions, all of which can also be defined and used with `Nonconvex`.

## Explicit parameters

There are 4 components to any implicit function:
1. The parameters `p`
2. The variables `x`
3. The residual `f(p, x)`, which is used to define `x(p)` as the `x` that satisfies `f(p, x) == 0` for a given value of `p`
4. The algorithm used to evaluate `x(p)` satisfying the condition `f(p, x) == 0`

In order to define a differentiable implicit function using `Nonconvex`, you have to specify the "forward" algorithm which finds `x(p)`. For instance, consider the following example:
```julia
using SparseArrays, NLsolve, Zygote, Nonconvex

N = 10
A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
p0 = randn(N)

f(p, x) = A * x + 0.1 * x.^2 - p
function forward(p)
    # Solve the nonlinear system of equations f(p, x) == 0 for x
    sol = nlsolve(x -> f(p, x), zeros(N), method = :anderson, m = 10)
    # Return the zero found (ignore the second returned value for now)
    return sol.zero, nothing
end
```
`forward` above solves for `x` in the nonlinear system of equations `f(p, x) == 0` given the value of `p`. In this case, the residual function is the same as the function `f(p, x)` used in the forward pass. One can then use the two functions `forward` and `f` to define an implicit function using:
```julia
imf = ImplicitFunction(forward, f)
xstar = imf(p0)
```
where `imf(p0)` solves the nonlinear system for `p = p0` and returns the zero `xstar` of the nonlinear system. This function can now be part of any arbitrary Julia function differentiated by Zygote. For example, it can be part of an objective function in a gradient-based optimization problem:
```julia
obj(p) = sum(imf(p))
g = Zygote.gradient(obj, p0)[1]
```

In the implicit function's adjoint rule definition, the partial Jacobian `∂f/∂x` is used according to the implicit function theorem. Often this Jacobian, or a good approximation of it, is a by-product of the `forward` function. For example, when `forward` runs an optimization using a BFGS-based approximation of the Hessian of the Lagrangian function, the final BFGS matrix can be a good approximation of `∂f/∂x`, where the residual `f` is the gradient of the Lagrangian function with respect to `x`. In those cases, the Jacobian by-product can be returned as the second output of `forward` instead of `nothing`, as sketched below.
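
The following is a minimal sketch of this pattern, reusing `f` and `N` from above. The name `forward_with_jac` is hypothetical, and computing the exact Jacobian with `Zygote.jacobian` is purely illustrative; in practice, the Jacobian or its approximation would fall out of the solver itself:
```julia
function forward_with_jac(p)
    sol = nlsolve(x -> f(p, x), zeros(N), method = :anderson, m = 10)
    xstar = sol.zero
    # Illustrative only: exact Jacobian ∂f/∂x at the zero. A solver by-product,
    # e.g. a quasi-Newton approximation, could be returned here instead.
    dfdx = Zygote.jacobian(x -> f(p, x), xstar)[1]
    return xstar, dfdx
end
imf = ImplicitFunction(forward_with_jac, f)
```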

## Implicit parameters

In some cases, it may be more convenient to avoid having to specify `p` as an explicit argument in `forward` and `f`. The following is also valid to use and will give correct gradients with respect to `p`:
```julia
function obj(p)
    N = length(p)
    f(x) = A * x + 0.1 * x.^2 - p
    function forward()
        # Solve the nonlinear system of equations f(x) == 0 for x
        sol = nlsolve(f, zeros(N), method = :anderson, m = 10)
        # Return the zero found (ignore the second returned value for now)
        return sol.zero, nothing
    end
    imf = ImplicitFunction(forward, f)
    return sum(imf())
end
g = Zygote.gradient(obj, p0)[1]
```
Notice that `p` was not an explicit argument to `f` or `forward` in the above example and that the implicit function is called using `imf()`. Using some explicit parameters and some implicit parameters is also supported, as in the sketch below.
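
The following is a hedged sketch of mixing the two styles, under the assumption that explicit parameters are passed as arguments while implicit ones are captured from the enclosing scope, combining the two patterns shown above (the name `obj2` and the second parameter `q` are purely illustrative):
```julia
function obj2(p, q)
    # p is an explicit parameter; q is implicit, captured from the closure
    f(p, x) = A * x + 0.1 * x.^2 - p - q
    function forward(p)
        sol = nlsolve(x -> f(p, x), zeros(length(p)), method = :anderson, m = 10)
        return sol.zero, nothing
    end
    imf = ImplicitFunction(forward, f)
    return sum(imf(p))
end
gp, gq = Zygote.gradient(obj2, p0, randn(N))
```
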
## Matrix-free linear solver in the adjoint

In the adjoint definition of implicit functions, a linear system:
```julia
(∂f/∂x)' * u = v
```
is solved to find the adjoint vector `u`. To solve the system using a matrix-free iterative solver (GMRES by default) that avoids constructing the Jacobian `∂f/∂x`, you can set the `matrixfree` keyword argument to `true` (the default is `false`).

When `matrixfree` is set to `false`, the entire Jacobian matrix is formed and the linear system is solved using an LU factorization.
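
Assuming the `matrixfree` keyword is passed to the `ImplicitFunction` constructor, enabling the matrix-free solver might look like:
```julia
# Hedged sketch: passing `matrixfree` at construction time is an assumption.
imf = ImplicitFunction(forward, f, matrixfree = true)
```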

## Arbitrary data structures

Both `p` and `x` above can be arbitrary data structures, not just arrays of numbers.
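
For instance, here is a minimal sketch with the parameters gathered in a `NamedTuple` rather than a flat vector, reusing `A`, `p0`, and `N` from above (the names `f_nt` and `forward_nt` are hypothetical):
```julia
f_nt(p, x) = p.A * x + p.c .* x.^2 - p.b
function forward_nt(p)
    sol = nlsolve(x -> f_nt(p, x), zeros(length(p.b)), method = :anderson, m = 10)
    return sol.zero, nothing
end
imf_nt = ImplicitFunction(forward_nt, f_nt)
xstar = imf_nt((A = A, b = p0, c = fill(0.1, N)))
```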

## Tolerance

The implicit function theorem assumes that the condition `f(p, x) == 0` is satisfied. In practice, it will only be satisfied approximately. When the condition is violated, the gradient reported by the implicit function theorem cannot be trusted, since the theorem's assumption is violated. The maximum tolerance allowed to "accept" the solution `x(p)` and its gradient is given by the keyword argument `tol` (default value is `1e-5`). When the norm of the residual `f(p, x)` is greater than this tolerance, `NaN`s are returned for the gradient instead of the value computed via the implicit function theorem. If, additionally, the keyword argument `error_on_tol_violation` is set to `true` (default value is `false`), an error is thrown if the norm of the residual exceeds the specified tolerance `tol`.
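
Assuming both keyword arguments are passed to the `ImplicitFunction` constructor, tightening the tolerance might look like:
```julia
# Hedged sketch: `tol` and `error_on_tol_violation` are the keyword arguments
# described above; passing them at construction time is an assumption.
imf = ImplicitFunction(forward, f, tol = 1e-8, error_on_tol_violation = true)
```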