Comments (7)
Thanks @ChrisRackauckas for reporting. Using ParallelStencil in inversion frameworks based on adjoint rules definitely has potential (here we derived the adjoint for a multi-physics 3D problem, "not automatically").
However, we are not familiar with Zygote-based VJPs, nor with ReverseDiffVJPs. To understand what could be modified in ParallelStencil to support AD and adjoint calculations with your tools, could you be more explicit about what you think is currently missing?
Also, can you tell us what the requirements are on the function to be passed to Zygote (or similar), so we can make ParallelStencil fit these requirements?
from parallelstencil.jl.
Yeah so I was playing around with a few things this week that were relevant. The key here is:
- ReverseDiff works on mutation, can work element-wise, but only on CPU
- Zygote works on CPU+GPU, cannot do mutation and generally is only good on higher level operations.
So one thing you could do is define https://github.com/JuliaDiff/ChainRules.jl rules over the stencil computation for Zygote. An example of doing this can be found in Tullio https://github.com/mcabbott/Tullio.jl. That doesn't solve the mutation problem. So I think an interesting solution here could be to get Enzyme compatibility (@wsmoses) https://github.com/wsmoses/Enzyme.jl. For the GPU code here, it might already have the ability to differentiate the kernel. And this is compatible with mutation.
The key here is that if you get the derivative of that stencil computation working, then the adjoint functionality of DifferentialEquations.jl only needs the right-hand side f of u'=f(u,p,t) to be reverse-mode compatible in order to generate the full adjoint, so these examples would work once that is in place. I added the ability yesterday to use EnzymeVJP for Enzyme-generated vector-Jacobian product kernels, so that could possibly be the solution here, but I haven't gone back to test this case yet.
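As a rough illustration of the ChainRules suggestion above, here is a minimal sketch of a hand-written reverse rule for a stencil. The function `lap1d` is a hypothetical 1D stand-in (a second difference on interior points), not a ParallelStencil kernel; only `ChainRulesCore.rrule`, `NoTangent` and `unthunk` are actual ChainRulesCore names. The pullback applies the transposed stencil as a scatter, without ever forming a matrix.

```julia
using ChainRulesCore

# Hypothetical 1D stencil: second difference on interior points, boundaries left at zero.
function lap1d(u::AbstractVector)
    out = zero(u)
    for i in 2:length(u)-1
        out[i] = u[i-1] - 2 * u[i] + u[i+1]
    end
    return out
end

# Reverse rule: the pullback computes J' * dy, where J is the Jacobian of lap1d.
function ChainRulesCore.rrule(::typeof(lap1d), u::AbstractVector)
    y = lap1d(u)
    function lap1d_pullback(dy)
        d = unthunk(dy)
        du = zero(u)
        for i in 2:length(u)-1
            # Row i of J has entries (1, -2, 1) at columns (i-1, i, i+1);
            # transposing turns the gather into a scatter.
            du[i-1] += d[i]
            du[i]   -= 2 * d[i]
            du[i+1] += d[i]
        end
        return NoTangent(), du
    end
    return y, lap1d_pullback
end
```

With such a rule defined, Zygote should pick it up when differentiating code that calls `lap1d`. As noted above, this does not address mutation: the rule wraps an out-of-place function.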
Thanks @ChrisRackauckas for your suggestions and further insights. I am still wondering which approach the workflows you describe follow under the hood to automatically retrieve the objects needed to perform inversions that use adjoints to compute gradients. Independently of the solution method, one would need an expression for the transposed Jacobian (or a function for it to do it in a matrix-free fashion). After discussing with @greuber, deriving the expression for the transposed Jacobian for the adjoint is the challenge. One can either do that analytically, which is what we did here (see Appendix 6), ending up with an infinite dimensional system of equations that one can discretize and solve with our preferred solver.
One question is would AD deliver something similar (transposed Jacobian), and if so, in what form ?
Maybe we could use the nonlinear diffusion example (from Appendix 6) as an MWE to try out those things and see if your tools combined with ParallelStencil could retrieve the equations needed to solve the matrix-free system in a fashion similar to what we do in the nonlinear 1D diffusion code here, or maybe even better. (If so, I can quickly rewrite the 1D Matlab example in Julia.)
> Independently of the solution method, one would need an expression for the transposed Jacobian (or a function for it to do it in a matrix-free fashion).
Reverse-mode AD is matrix-free transposed Jacobian calculations. I would recommend taking a look at https://github.com/mitmath/18337 if you're curious about that, specifically https://mitmath.github.io/18337/lecture10/estimation_identification .
With Zygote, the vector-transposed Jacobian product f'(x)'*v is calculated via:
y,back = Zygote.pullback(f,x)
back(v)
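For a concrete feel of what `back(v)` returns, here is a small self-contained sketch on a toy function (not the stencil code). The function `f` and the matrix `A` are made up for illustration. For `f(x) = A * (x .^ 2)` the Jacobian is `A * Diagonal(2x)`, so the transposed product is `2x .* (A' * v)`, and `back(v)` returns it as a one-element tuple (one entry per argument of `f`).

```julia
using Zygote

# Toy function for illustration only: f(x) = A * (x .^ 2), no mutation involved.
A = [1.0 2.0 0.0;
     0.0 1.0 3.0]
f(x) = A * (x .^ 2)

x = [1.0, 2.0, 3.0]
v = [0.5, -1.0]            # vector to multiply with the transposed Jacobian

y, back = Zygote.pullback(f, x)
vjp = back(v)[1]           # f'(x)' * v, computed without forming the Jacobian

# Analytic check: J = A * Diagonal(2x), hence J' * v = 2x .* (A' * v).
expected = 2 .* x .* (A' * v)
```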
The issue is that this fails on your stencil functions. I have been playing around with Enzyme on other projects this last week, and I think that might be the right one for your case. Since you're generating GPU code, it should be statically compilable, which allows Enzyme to run its AD passes on the LLVM IR. The vector-transposed Jacobian product in that case is done via:
Enzyme.autodiff(Duplicated(y, v),
                Duplicated(x, λ)) do _y,_x
    f!(_y,_x)
    nothing
end
for a non-allocating mutating function f!(y,x) to calculate f'(x)'*v. This should be a lot faster, if the parallel primitives are able to be handled successfully.
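The same idea as a hedged, self-contained sketch on a plain CPU loop. This assumes a recent Enzyme.jl release, where the call is written `autodiff(Reverse, f!, Const, ...)` rather than the do-block form shown above, and `lap1d!` is a hypothetical 1D stand-in for a stencil kernel, not ParallelStencil code. The shadow of the output is seeded with `v`, and the shadow of the input accumulates `f'(x)'*v`.

```julia
using Enzyme

# Hypothetical mutating 1D stencil: writes the second difference of x into y (interior only).
function lap1d!(y, x)
    for i in 2:length(x)-1
        y[i] = x[i-1] - 2 * x[i] + x[i+1]
    end
    return nothing
end

x  = collect(1.0:6.0)
y  = zeros(6)
v  = [0.5, -1.0, 2.0, 0.0, 3.0, 1.0]

dy = copy(v)     # shadow of the output, seeded with v (Enzyme may overwrite it)
dx = zeros(6)    # shadow of the input, accumulates f'(x)' * v

# Const: the function returns nothing, so there is no active return value.
Enzyme.autodiff(Reverse, lap1d!, Const, Duplicated(y, dy), Duplicated(x, dx))
```

After the call, `dx` holds the transposed-Jacobian product. Note that `dx` is accumulated into, so it has to be zeroed before each new call.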
> One can either do that analytically, what we did here (see Appendix 6), ending up with an infinite dimensional system of equations that one can discretize and solve with our preferred solver.
The final solution is a mixture. Discrete adjoint sensitivity analysis does the entire adjoint via reverse-mode AD, but that is costly memory-wise. Continuous adjoint sensitivity analysis uses the infinite dimensional system so you can solve forward and reverse with a preferred solver, but then the right-hand side of the adjoint equation is automatically generated via these reverse-mode AD tools to get fast vector-transposed Jacobian products. That's what the sensitivity methods are doing (https://diffeq.sciml.ai/stable/analysis/sensitivity/), and if you watch the video on the SciML adjoint system you'll see the trade-offs between all of the different choices of vector-transposed Jacobian product (vjp) and their mixtures with generated adjoints (https://www.youtube.com/watch?v=XRJ-rtP2fVE).
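To make the role of the vjp in the continuous approach concrete, here is the standard adjoint system for an integral cost, written as a sketch. The cost density g and the sign convention for λ are assumptions chosen for illustration and may differ from the convention used in the DiffEqSensitivity documentation. The only place the model enters the backward solve is the term (∂f/∂u)ᵀλ, which is exactly the vector-transposed Jacobian product that reverse-mode AD supplies.

```latex
\begin{aligned}
G(p) &= \int_{t_0}^{T} g(u,p,t)\,dt, \qquad u' = f(u,p,t),\\
\lambda'(t) &= -\left(\frac{\partial f}{\partial u}\right)^{\!\top}\lambda
               - \left(\frac{\partial g}{\partial u}\right)^{\!\top},
               \qquad \lambda(T) = 0,\\
\frac{dG}{dp} &= \lambda(t_0)^{\top}\frac{\partial u(t_0)}{\partial p}
               + \int_{t_0}^{T}\left(\frac{\partial g}{\partial p}
               + \lambda^{\top}\frac{\partial f}{\partial p}\right)dt.
\end{aligned}
```

The λᵀ ∂f/∂p term in the gradient integral is a second vector-Jacobian product, this time with respect to the parameters, and can come from the same AD call.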
All of this is done automatically though. In the code above, when I did Zygote.gradient(loss,[1.0]), it sees the solve in there and automatically knows to solve forward, and then generate the adjoint code to solve backwards. Where it fails is in the generated rhs expression: it fails inside of the vjp (using ReverseDiff for the vjp) on diffusion3D_step! because ReverseDiff.jl has some type restrictions. Zygote.pullback is not amenable to mutation, so EnzymeVJP is probably the thing to try next.
Hopefully that explains what all is going on and how it's pulling in two different levels of reverse-mode AD to chain together gradient calculations.
@wsmoses on:
const USE_GPU = false
using ParallelStencil, OrdinaryDiffEq
using ParallelStencil.FiniteDifferences3D
@static if USE_GPU
    @init_parallel_stencil(CUDA, Float64, 3);
else
    @init_parallel_stencil(Threads, Float64, 3);
end

@parallel function diffusion3D_step!(T2, T, Ci, lam, dx, dy, dz)
    @inn(T2) = lam*@inn(Ci)*(@d2_xi(T)/dx^2 + @d2_yi(T)/dy^2 + @d2_zi(T)/dz^2);
    return
end

function diffusion3D(lam,alg)
    # Physics
    cp_min = 1.0;                 # Minimal heat capacity
    lx, ly, lz = 10.0, 10.0, 10.0; # Length of domain in dimensions x, y and z.
    # Numerics
    nx, ny, nz = 16, 16, 16;      # Number of gridpoints dimensions x, y and z.
    nt = 100;                     # Number of time steps
    dx = lx/(nx-1);               # Space step in x-dimension
    dy = ly/(ny-1);               # Space step in y-dimension
    dz = lz/(nz-1);               # Space step in z-dimension
    # Array initializations
    T = @zeros(nx, ny, nz);
    T2 = @zeros(nx, ny, nz);
    Ci = @zeros(nx, ny, nz);
    # Initial conditions (heat capacity and temperature with two Gaussian anomalies each)
    Ci .= 1.0./( cp_min .+ Data.Array([5*exp(-(((ix-1)*dx-lx/1.5))^2-(((iy-1)*dy-ly/2))^2-(((iz-1)*dz-lz/1.5))^2) +
                                       5*exp(-(((ix-1)*dx-lx/3.0))^2-(((iy-1)*dy-ly/2))^2-(((iz-1)*dz-lz/1.5))^2) for ix=1:size(T,1), iy=1:size(T,2), iz=1:size(T,3)]) )
    T .= Data.Array([100*exp(-(((ix-1)*dx-lx/2)/2)^2-(((iy-1)*dy-ly/2)/2)^2-(((iz-1)*dz-lz/3.0)/2)^2) +
                     50*exp(-(((ix-1)*dx-lx/2)/2)^2-(((iy-1)*dy-ly/2)/2)^2-(((iz-1)*dz-lz/1.5)/2)^2) for ix=1:size(T,1), iy=1:size(T,2), iz=1:size(T,3)])
    T2 .= T;                      # Assign also T2 to get correct boundary conditions.
    dt = min(dx^2,dy^2,dz^2)*cp_min/8.1; # Time step for the 3D Heat diffusion
    function f(du,u,p,t)
        @show t
        @parallel diffusion3D_step!(du, u, Ci, p[1], dx, dy, dz);
    end
    prob = ODEProblem(f, T, (0.0,nt*dt), lam)
    sol = solve(prob, alg, save_everystep = false, save_start = false, sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))
end

sol = diffusion3D([1.0],ROCK2())

using ForwardDiff, Zygote, DiffEqSensitivity
function loss(p)
    sum(diffusion3D(p,ROCK2()))
end
ForwardDiff.gradient(loss,[1.0])
Zygote.gradient(loss,[1.0])
I'm getting:
TypeError: in Type, in parameter, expected Type, got a value of type DiffEqSensitivity.var"#109#124"{var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}}
Val(x::Function) at essentials.jl:693
autodiff at Enzyme.jl:60 [inlined]
_vecjacobian!(dλ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, y::Array{Float64, 3}, λ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float64}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, DiffEqSensitivity.var"#109#124"{var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Tuple{Array{Float64, 3}, Vector{Float64}, Array{Float64, 3}, Array{Float64, 3}}, Nothing, Nothing, Array{Float64, 3}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP, Bool}, Array{Float64, 3}, ODESolution{Float64, 4, Vector{Array{Float64, 3}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Array{Float64, 3}}}, ODEProblem{Array{Float64, 3}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ROCK2{Nothing}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Array{Float64, 3}}, Vector{Float64}, Vector{Vector{Array{Float64, 3}}}, OrdinaryDiffEq.ROCK2Cache{Array{Float64, 3}, Array{Float64, 3}, Array{Float64, 3}, OrdinaryDiffEq.ROCK2ConstantCache{Float64, Float64, Array{Float64, 3}}}}, DiffEqBase.DEStats}, DiffEqSensitivity.CheckpointSolution{ODESolution{Float64, 4, Vector{Array{Float64, 3}}, Nothing, Nothing, 
Vector{Float64}, Vector{Vector{Array{Float64, 3}}}, ODEProblem{Array{Float64, 3}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ROCK2{Nothing}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Array{Float64, 3}}, Vector{Float64}, Vector{Vector{Array{Float64, 3}}}, OrdinaryDiffEq.ROCK2Cache{Array{Float64, 3}, Array{Float64, 3}, Array{Float64, 3}, OrdinaryDiffEq.ROCK2ConstantCache{Float64, Float64, Array{Float64, 3}}}}, DiffEqBase.DEStats}, Vector{Tuple{Float64, Float64}}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}, Nothing}, ODEProblem{Array{Float64, 3}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, isautojacvec::EnzymeVJP, dgrad::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, 
dy::Nothing, W::Nothing) at derivative_wrappers.jl:471
vecjacobian!(dλ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, y::Array{Float64, 3}, λ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float64}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, DiffEqSensitivity.var"#109#124"{var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Tuple{Array{Float64, 3}, Vector{Float64}, Array{Float64, 3}, Array{Float64, 3}}, Nothing, Nothing, Array{Float64, 3}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, Enzym...
Can Enzyme not handle captured variables or is that fixed on some branch?
> Can Enzyme not handle captured variables or is that fixed on some branch?
That should be fixed nowadays. (Except that this doesn't work on the GPU).
> Except that this doesn't work on the GPU
What would be needed for it to run on GPU?