When implementing the frules, I realized that the current design of `frule` doesn't allow for the standard optimizations seen in forward-mode implementations. The reason is that forward mode propagates the derivative along with the primal, step by step, so the tangent has to be known while the primal is being computed. A good primer on all of this is this set of notes:
https://mitmath.github.io/18337/lecture9/autodiff_dimensions
Essentially, what we are trying to do with ChainRules.jl is allow the user to describe how to calculate `f(x)` and `f'(x)v`, the primal and the jvp. Currently the formulation is:
```julia
function frule(::typeof(foo), args; kwargs...)
    # ... compute y and define pushforward(dargs) ...
    return y, pushforward
end
```
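where `pushforward` is a function called as `pushforward(dargs)`. However, given the discussion of forward-mode differentiation in those notes, one can see that this runs contrary to how forward mode is actually computed: the primal `y` is produced before any tangent `dargs` is known. Here are two examples of this.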
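To make the current shape concrete, here is a minimal sketch for `sin`. The name `frule_current` is hypothetical (chosen so it does not clash with ChainRules' own `frule`); the only point is that the primal is computed before any tangent is known, and the tangent enters later through the returned closure.

```julia
# Hypothetical illustration of the current (primal, pushforward) shape.
function frule_current(::typeof(sin), x)
    y = sin(x)                      # primal computed eagerly, without dx
    pushforward(dx) = cos(x) * dx   # the tangent is only supplied afterwards
    return y, pushforward
end

y, pf = frule_current(sin, 0.0)
```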
Example 1: Implementing ForwardDiff over frules
As described in the notes, the dual-number way of computing forward mode starts by seeding dual numbers. In standard ForwardDiff usage, these seeds are all unique basis vectors, as shown in the DiffEq documentation on how to AD through the solver manually:
https://docs.juliadiffeq.org/v6.8/analysis/sensitivity/#Examples-using-ForwardDiff.jl-1
But as mentioned in the notes, what this is really doing is seeding the duals along the basis vector directions `e_i`, so the jvp computes `J*e_1`, `J*e_2`, `J*e_3`, ... as separate vectors, which together give a representation of the full Jacobian. Once you have the whole Jacobian you can of course compute `J*v`, and this is what the current `frule` would allow:
```julia
function frule(::typeof(f), x; kwargs...)
    dual_x = seed_duals(x)  # seeds along the basis vector directions,
                            # giving dual numbers with length(x) partials
    dual_y = f(dual_x)      # propagate all length(x) partials through f
    y, dy = value(dual_y), partials_as_matrix(dual_y)  # y is the primal, dy is the Jacobian
    function pushforward(dx)
        return dy * dx
    end
    return y, pushforward
end
```
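A self-contained sketch of this Jacobian-first approach, assuming a tiny hand-written single-partial `Dual` type in place of ForwardDiff's (so no packages are needed). One difference from ForwardDiff: ForwardDiff carries all `length(x)` partials in one pass, while this sketch does one pass per basis vector `e_i`; either way the derivative work grows with `length(x)`. `Dual`, `value`, `partial`, and `frule_jacobian` are hypothetical names.

```julia
# Minimal single-partial dual number (hypothetical stand-in for ForwardDiff.Dual).
struct Dual <: Number
    val::Float64
    der::Float64
end
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)
value(d::Dual) = d.val
partial(d::Dual) = d.der

# Example function R^2 -> R^2 built only from the operations defined above.
f(x) = [x[1] * x[2], sin(x[1]) + x[2]]

# Current-style rule: seed along every basis vector e_i (one pass per column),
# assemble the full Jacobian J, and only later apply it to dx.
function frule_jacobian(f, x)
    n = length(x)
    y = f(x)
    cols = map(1:n) do i
        e_i = [j == i ? 1.0 : 0.0 for j in 1:n]
        partial.(f(Dual.(x, e_i)))   # column J * e_i
    end
    J = reduce(hcat, cols)
    pushforward(dx) = J * dx
    return y, pushforward
end

x = [1.0, 2.0]
y, pf = frule_jacobian(f, x)
```

Here `pf(dx)` only multiplies by the stored `J`; all the derivative work was already spent building `J`, whatever `dx` turns out to be.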
However, this shows that there is a more efficient way to calculate `y, dy*dx`: if we know `dx` at the start, we can just seed the dual numbers along the direction of `dx`, which changes the number of partials carried by each dual from `length(x)` to 1:
```julia
function frule(::typeof(f), x, dx; kwargs...)
    dual_y = f(dual.(x, dx))  # 2-dimensional number: value plus one partial
    y, dy = value(dual_y), partials_as_matrix(dual_y)  # y is the primal, dy is f'(x)*dx
    return y, dy
end
```
This changes the derivative propagation from O(n) work per operation (one partial per input dimension) to O(1)!
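For comparison, a sketch of the directional rule under the same assumptions: the hypothetical single-partial `Dual`, `value`, and `partial` are redefined so this block runs on its own, and `frule_directional` is a made-up name, not a ChainRules function. Seeding `Dual.(x, dx)` propagates a single partial, and the result equals `J*dx` without ever forming `J`.

```julia
# Minimal single-partial dual number (hypothetical stand-in for ForwardDiff.Dual).
struct Dual <: Number
    val::Float64
    der::Float64
end
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)
value(d::Dual) = d.val
partial(d::Dual) = d.der

f(x) = [x[1] * x[2], sin(x[1]) + x[2]]

# Directional rule: x and dx arrive together, so we seed once along dx.
# One pass with one partial yields both the primal y and the tangent J*dx.
function frule_directional(f, x, dx)
    dual_y = f(Dual.(x, dx))
    return value.(dual_y), partial.(dual_y)
end

y, dy = frule_directional(f, [1.0, 2.0], [0.5, -1.0])
```

For this `f`, the Jacobian at `x = [1.0, 2.0]` is `[2 1; cos(1) 1]`, so the expected tangent is `[0.0, 0.5*cos(1) - 1]`, which is what the single seeded pass returns.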
Example 2: Implementation of Forward Sensitivity Analysis for ODEs
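Now here's a more concrete example from the user side. For ODEs, you want to look at

$$u' = f(u, p, t)$$

and you want to know how $u(t)$ changes w.r.t. $p$. So take $d/dp$ of both sides of the ODE, and by the chain rule (swapping the order of differentiation, assuming nice properties) you get

$$\frac{d}{dt}\frac{du}{dp} = \frac{\partial f}{\partial u}\frac{du}{dp} + \frac{\partial f}{\partial p}.$$

Calling $S = du/dp$, this is just

$$S' = \frac{\partial f}{\partial u} S + \frac{\partial f}{\partial p}.$$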
So you get another ODE that gives you the derivatives of the solution of the original ODE w.r.t. the parameters. This is the continuous pushforward rule! The difficulty is that you need to be able to calculate $(\partial f/\partial u)(t)$, which requires that you know $u(t)$. In theory you could calculate $u(t)$ as a continuous solution beforehand, by solving the original ODE and storing it, but that's not the good way to do it. The way to do it is to realize that if you solve the ODEs

$$\begin{aligned} u' &= f(u, p, t) \\ S' &= \frac{\partial f}{\partial u} S + \frac{\partial f}{\partial p} \end{aligned}$$

together, then you always know $u$, since it's the first part of the system! So no stored solution is needed, and this is very efficient.
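As a sanity check of the augmented-system idea, here is a minimal sketch on a hypothetical scalar test problem, $u' = -p u$, integrated with plain explicit Euler instead of a real DiffEq solver (both are assumptions for illustration only). `u` and `S` are stepped together, so $\partial f/\partial u$ is always evaluated at the current `u` and nothing has to be stored. `solve_with_sensitivity` is a made-up name.

```julia
# Forward sensitivity for the hypothetical test ODE u' = f(u, p) = -p*u.
# S = du/dp obeys S' = (df/du)*S + df/dp = -p*S - u, starting from S(0) = 0.
function solve_with_sensitivity(u0, p, T; nsteps = 100_000)
    dt = T / nsteps
    u, S = u0, 0.0
    for _ in 1:nsteps
        du = -p * u          # f(u, p)
        dS = -p * S - u      # (df/du)*S + df/dp, using the current u
        u, S = u + dt * du, S + dt * dS   # explicit Euler on the stacked system
    end
    return u, S
end

u, S = solve_with_sensitivity(1.0, 0.5, 1.0)
```

The exact answers are $u(1) = e^{-0.5}$ and $du/dp = -e^{-0.5}$; the Euler result agrees to within the discretization error.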
That's almost there. What sensitivities are we pushing forward, though? You can seed the sensitivities with $S(0) = 0$ and get the output $S = du/dp$, but that's not satisfying. What if you wanted to know both $du/du_0$ and $du/dp$? Since `concrete_solve(p, u0, odeprob, solver, ...)` is a function of both `p` and `u0`, we want the derivative of the ODE's solution with respect to both `p` and `u0`.
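It turns out from simple math that all you have to do is start the sensitivity at $S(0) = du_0$ and weigh the $\partial f/\partial p$ term by the direction $dp$, i.e. solve $S' = \frac{\partial f}{\partial u} S + \frac{\partial f}{\partial p}\,dp$! So then, in "composed frule" notation, you'd do the following:

```julia
function frule(::typeof(concrete_solve), p, dp, u0, du0, odeprob, solver)
    # Augment the ODE with the directional sensitivity S, started from S(0) = du0,
    # with the df/dp forcing term weighed by the direction vector dp
    _prob = build_bigger_ode(odeprob, p, dp, [u0, du0])
    sol = solve(_prob, solver)
    y, dy = split_solution(sol)  # y is the primal solution, dy is the JVP along (dp, du0)
    return y, dy
end
```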
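Right now, this can't really be expressed with the current `frule`, because the rule only receives `dp` and `du0` after the primal solve has already happened.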
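The directional version can be sketched the same way, again on the hypothetical problem $u' = -p u$ with explicit Euler. `frule_solve` is a made-up name standing in for an `frule` on `concrete_solve`; note that it needs `p`, `dp`, `u0`, and `du0` before the integration starts.

```julia
# Directional forward sensitivity for the hypothetical ODE u' = -p*u along (dp, du0).
# The tangent obeys S' = (df/du)*S + (df/dp)*dp = -p*S - u*dp with S(0) = du0,
# so both the inputs (p, u0) and the directions (dp, du0) are needed up front.
function frule_solve(p, dp, u0, du0, T; nsteps = 100_000)
    dt = T / nsteps
    u, S = u0, du0
    for _ in 1:nsteps
        du = -p * u
        dS = -p * S - u * dp
        u, S = u + dt * du, S + dt * dS   # explicit Euler on the stacked system
    end
    return u, S   # primal u(T) and the JVP along (dp, du0)
end

y, dy = frule_solve(0.5, 1.0, 2.0, 3.0, 1.0)
```

Here the exact JVP is $e^{-pT}\,du_0 - T u_0 e^{-pT}\,dp$, which for these inputs is $e^{-0.5}$, with primal $2e^{-0.5}$.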
API
Actually, having those arguments in the signature might be difficult, so maybe it's easier to write it as:
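```julia
function frule(::typeof(f), x; kwargs...)
    function pushforward(dx)
        # both the primal and the tangent are computed here, once dx is known
        dual_y = f(dual.(x, dx))  # 2-dimensional number: value plus one partial
        y, dy = value(dual_y), partials_as_matrix(dual_y)  # y is the primal, dy is f'(x)*dx
        return y, dy
    end
    return pushforward
end
```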
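A runnable sketch of that closure-returning shape, again using a hypothetical single-partial `Dual` as a stand-in for ForwardDiff's. `frule_lazy` is a made-up name; the design point is that nothing is computed until `pushforward(dx)` is called, so seeding along `dx` is always possible, at the cost of recomputing the primal for every new `dx`.

```julia
# Minimal single-partial dual number (hypothetical stand-in for ForwardDiff.Dual).
struct Dual <: Number
    val::Float64
    der::Float64
end
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)
value(d::Dual) = d.val
partial(d::Dual) = d.der

f(x) = [x[1] * x[2], sin(x[1]) + x[2]]

# The rule only captures x; primal and tangent are both computed inside the
# closure, so each call seeds the duals along the dx it receives.
function frule_lazy(f, x)
    function pushforward(dx)
        dual_y = f(Dual.(x, dx))                   # single pass seeded along dx
        return value.(dual_y), partial.(dual_y)    # (y, J*dx)
    end
    return pushforward
end

pf = frule_lazy(f, [1.0, 2.0])
y, dy = pf([1.0, 1.0])
```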
Anyway, the exact API is an interesting question, but whatever it is, the computation should have access to `x` and `dx` at the same time.