
Comments (7)

luraess commented on May 30, 2024

Thanks @ChrisRackauckas for reporting. Using ParallelStencil in inversion frameworks based on adjoint rules definitely has potential (here we derived the adjoint for a multi-physics 3D problem - "not automatically").
However, we are not familiar with Zygote-based VJPs, nor with ReverseDiffVJP. In order to understand what may need to be modified in ParallelStencil to support AD and adjoint calculations with your tools, could you be more explicit about what you think is currently missing?
Also, can you tell us what the requirements are on the function to be passed to Zygote (or similar), so we can make ParallelStencil fit these requirements?


ChrisRackauckas commented on May 30, 2024

Yeah so I was playing around with a few things this week that were relevant. The key here is:

  1. ReverseDiff supports mutation and can work element-wise, but only on the CPU.
  2. Zygote works on CPU+GPU, but cannot handle mutation and is generally only good on higher-level operations (see the minimal illustration below).
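
For instance, the mutation limitation shows up on even a one-line closure (a minimal illustration of the constraint, not code from this issue):

using Zygote

# Any in-place array write inside the differentiated function throws
# "Mutating arrays is not supported":
Zygote.gradient(x -> (x[1] *= 2; sum(x)), [1.0, 2.0])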

So one thing you could do is define https://github.com/JuliaDiff/ChainRules.jl rules over the stencil computation for Zygote. An example of doing this can be found in Tullio https://github.com/mcabbott/Tullio.jl. That doesn't solve the mutation problem. So I think an interesting solution here could be to get Enzyme compatibility (@wsmoses) https://github.com/wsmoses/Enzyme.jl. For the GPU code here, it might already have the ability to differentiate the kernel. And this is compatible with mutation.
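
To make the ChainRules route concrete, such a rule would look roughly like the sketch below, written over a hypothetical non-mutating 1D laplacian standing in for a real ParallelStencil kernel (an assumption for illustration, not ParallelStencil's API):

using ChainRulesCore

# Hypothetical 1D stencil: each interior output reads three neighboring inputs.
laplacian(T) = [T[i-1] - 2T[i] + T[i+1] for i in 2:length(T)-1]

function ChainRulesCore.rrule(::typeof(laplacian), T)
    y = laplacian(T)
    function laplacian_pullback(ȳ)
        T̄ = zero(T)
        # Adjoint of the stencil: scatter each output cotangent back onto
        # the three input cells the forward stencil read from.
        for i in 2:length(T)-1
            T̄[i-1] += ȳ[i-1]
            T̄[i]   -= 2ȳ[i-1]
            T̄[i+1] += ȳ[i-1]
        end
        return NoTangent(), T̄
    end
    return y, laplacian_pullback
end

Zygote consumes ChainRules rules automatically, so the kernel becomes opaque to the tracer and any mutation inside it never has to be differentiated through.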

The key here is that if you get the derivative of that stencil computation working, then the adjoint functionality of DifferentialEquations.jl only needs the right-hand side f of u' = f(u,p,t) to be reverse-mode compatible in order to generate the full adjoint, so these examples would work once that is in place. I added the ability yesterday to use EnzymeVJP for Enzyme-generated vector-Jacobian-product kernels, so that could possibly be the solution here, but I haven't gone back to test this case yet.
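
For reference, opting into the Enzyme-based vjp is just a keyword argument to solve (the same spelling as in the MWE further down):

using DiffEqSensitivity   # provides InterpolatingAdjoint and EnzymeVJP

sol = solve(prob, alg, save_everystep = false,
            sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))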


luraess commented on May 30, 2024

Thanks @ChrisRackauckas for your suggestions and further insights. I am still wondering which approach is generally followed under the hood by the workflows you describe to automatically retrieve the objects needed to perform inversions using the adjoint to compute gradients. Independently of the solution method, one would need an expression for the transposed Jacobian (or a function for it to do it in a matrix-free fashion). After discussing with @greuber, deriving the expression for the transposed Jacobian for the adjoint is the challenge. One can do that analytically, as we did here (see Appendix 6), ending up with an infinite-dimensional system of equations that one can discretize and solve with our preferred solver.

One question is: would AD deliver something similar (a transposed Jacobian), and if so, in what form?

Maybe we could use the nonlinear diffusion example (from Appendix 6) as an MWE to try those things out and see whether your tools combined with ParallelStencil could retrieve the equations needed to solve the matrix-free system in a similar fashion to the nonlinear 1D diffusion code here, or maybe even better. (If so, I can quickly rewrite the 1D Matlab example in Julia.)


ChrisRackauckas commented on May 30, 2024

Independently of the solution method, one would need an expression for the transposed Jacobian (or a function for it to do it in a matrix-free fashion).

Reverse-mode AD is matrix-free transposed-Jacobian calculation. I would recommend taking a look at https://github.com/mitmath/18337 if you're curious about that, specifically https://mitmath.github.io/18337/lecture10/estimation_identification.

With Zygote, the vector-transposed Jacobian product f'(x)'*v is calculated via:

y, back = Zygote.pullback(f, x)   # forward pass: computes y = f(x) and returns the pullback closure
back(v)                           # reverse pass: evaluates the vjp f'(x)'*v

The issue is that this fails on your stencil functions. I have been playing around with Enzyme on other projects this last week, and I think that might be the right one for your case. Since you're generating GPU code, it should be statically compilable, allowing Enzyme to run its AD passes on the LLVM IR. The vector-transposed Jacobian product in that case is done via:

Enzyme.autodiff(Duplicated(y, v),             # primal output y, shadow seeded with v
                Duplicated(x, λ)) do _y, _x   # λ accumulates the vjp f'(x)'*v
    f!(_y, _x)
    nothing
end

for a non-allocating mutating function f!(y, x), to calculate f'(x)'*v. This should be a lot faster, if the parallel primitives can be handled successfully.
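
As a sanity check of that call pattern on something trivial (a sketch only; scale2! is a made-up stand-in kernel, and the autodiff spelling follows the snippet above):

using Enzyme

scale2!(y, x) = (y .= 2 .* x; nothing)   # f(x) = 2x, so f'(x)'*v = 2v

x, y = rand(4), zeros(4)
v = ones(4)     # seed: adjoint of the output y
λ = zeros(4)    # accumulates the vjp f'(x)'*v

Enzyme.autodiff(Duplicated(y, v), Duplicated(x, λ)) do _y, _x
    scale2!(_y, _x)
    nothing
end

λ  # ≈ fill(2.0, 4)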

One can do that analytically, as we did here (see Appendix 6), ending up with an infinite-dimensional system of equations that one can discretize and solve with our preferred solver.

The final solution is a mixture. Discrete adjoint sensitivity analysis does the entire adjoint via reverse-mode AD, but that is costly memory-wise. Continuous adjoint sensitivity analysis uses the infinite-dimensional system, so you can solve forward and backward with a preferred solver, but the right-hand side of the adjoint equation is then generated automatically via these reverse-mode AD tools to get fast vector-transposed Jacobian products. That's what the sensitivity methods are doing (https://diffeq.sciml.ai/stable/analysis/sensitivity/), and if you watch the video on the SciML adjoint system you'll see the trade-offs between all the different choices of vector-transposed Jacobian product (vjp) and their mixtures with generated adjoints (https://www.youtube.com/watch?v=XRJ-rtP2fVE).
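
For the terminal-cost case G(p) = g(u(T)) with u' = f(u, p, t), the continuous adjoint equations are (standard adjoint sensitivity analysis, modulo sign conventions; see the sensitivity docs linked above):

\lambda'(t) = -\left(\frac{\partial f}{\partial u}\right)^{T} \lambda(t),
\qquad
\lambda(T) = \left(\frac{\partial g}{\partial u}\Big|_{u(T)}\right)^{T}

\frac{dG}{dp} = \lambda(t_0)^{T} \frac{\partial u_0}{\partial p}
              + \int_{t_0}^{T} \lambda(t)^{T} \frac{\partial f}{\partial p}\, dt

The products \lambda^{T}(\partial f/\partial u) and \lambda^{T}(\partial f/\partial p) are exactly the vector-transposed Jacobian products a reverse-mode pullback of f delivers, which is why the adjoint right-hand side can be generated automatically.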

All of this is done automatically, though. In the code above, when I did Zygote.gradient(loss, [1.0]), it sees the solve in there and automatically knows to solve forward and then generate the adjoint code to solve backwards. Where it fails is in the generated rhs expression: it fails inside the vjp (using ReverseDiff for the vjp) on diffusion3D_step!, because ReverseDiff.jl has some type restrictions. Zygote.pullback is not amenable to mutation, so EnzymeVJP is probably the thing to try next.

Hopefully that explains what all is going on and how it's pulling in two different levels of reverse-mode AD to chain together gradient calculations.


ChrisRackauckas commented on May 30, 2024

@wsmoses on:

const USE_GPU = false
using ParallelStencil, OrdinaryDiffEq
using DiffEqSensitivity                  # provides InterpolatingAdjoint and EnzymeVJP used in solve below
using ParallelStencil.FiniteDifferences3D
@static if USE_GPU
    @init_parallel_stencil(CUDA, Float64, 3);
else
    @init_parallel_stencil(Threads, Float64, 3);
end

@parallel function diffusion3D_step!(T2, T, Ci, lam, dx, dy, dz)
    @inn(T2) = lam*@inn(Ci)*(@d2_xi(T)/dx^2 + @d2_yi(T)/dy^2 + @d2_zi(T)/dz^2);
    return
end

function diffusion3D(lam,alg)
    # Physics
    cp_min     = 1.0;                                        # Minimal heat capacity
    lx, ly, lz = 10.0, 10.0, 10.0;                           # Length of domain in dimensions x, y and z.

    # Numerics
    nx, ny, nz = 16, 16, 16;                                 # Number of grid points in dimensions x, y and z.
    nt         = 100;                                        # Number of time steps
    dx         = lx/(nx-1);                                  # Space step in x-dimension
    dy         = ly/(ny-1);                                  # Space step in y-dimension
    dz         = lz/(nz-1);                                  # Space step in z-dimension

    # Array initializations
    T   = @zeros(nx, ny, nz);
    T2  = @zeros(nx, ny, nz);
    Ci  = @zeros(nx, ny, nz);

    # Initial conditions (heat capacity and temperature with two Gaussian anomalies each)
    Ci .= 1.0./( cp_min .+ Data.Array([5*exp(-(((ix-1)*dx-lx/1.5))^2-(((iy-1)*dy-ly/2))^2-(((iz-1)*dz-lz/1.5))^2) +
                                       5*exp(-(((ix-1)*dx-lx/3.0))^2-(((iy-1)*dy-ly/2))^2-(((iz-1)*dz-lz/1.5))^2) for ix=1:size(T,1), iy=1:size(T,2), iz=1:size(T,3)]) )
    T  .= Data.Array([100*exp(-(((ix-1)*dx-lx/2)/2)^2-(((iy-1)*dy-ly/2)/2)^2-(((iz-1)*dz-lz/3.0)/2)^2) +
                       50*exp(-(((ix-1)*dx-lx/2)/2)^2-(((iy-1)*dy-ly/2)/2)^2-(((iz-1)*dz-lz/1.5)/2)^2) for ix=1:size(T,1), iy=1:size(T,2), iz=1:size(T,3)])
    T2 .= T;                                                 # Assign also T2 to get correct boundary conditions.

    dt = min(dx^2,dy^2,dz^2)*cp_min/8.1;                 # Time step for the 3D Heat diffusion
    function f(du,u,p,t)
        @show t
        @parallel diffusion3D_step!(du, u, Ci, p[1], dx, dy, dz);
    end
    prob = ODEProblem(f, T, (0.0,nt*dt), lam)
    sol = solve(prob, alg, save_everystep = false, save_start = false, sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))
end

sol = diffusion3D([1.0],ROCK2())

using ForwardDiff, Zygote
function loss(p)
    sum(diffusion3D(p,ROCK2()))
end
ForwardDiff.gradient(loss,[1.0])
Zygote.gradient(loss,[1.0])

I'm getting:

TypeError: in Type, in parameter, expected Type, got a value of type DiffEqSensitivity.var"#109#124"{var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}}
Val(x::Function) at essentials.jl:693
autodiff at Enzyme.jl:60 [inlined]
_vecjacobian!(dλ::SubArray{Float64, 1, Vector{Float64}, …}, y::Array{Float64, 3}, λ::SubArray{Float64, 1, Vector{Float64}, …}, p::Vector{Float64}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{…}, isautojacvec::EnzymeVJP, dgrad::SubArray{Float64, 1, Vector{Float64}, …}, dy::Nothing, W::Nothing) at derivative_wrappers.jl:471
vecjacobian!(dλ::SubArray{Float64, 1, Vector{Float64}, …}, y::Array{Float64, 3}, λ::SubArray{Float64, 1, Vector{Float64}, …}, p::Vector{Float64}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{…}) ...

Can Enzyme not handle captured variables or is that fixed on some branch?


vchuravy commented on May 30, 2024

Can Enzyme not handle captured variables or is that fixed on some branch?

That should be fixed nowadays. (Except that this doesn't work on the GPU).


luraess commented on May 30, 2024

Except that this doesn't work on the GPU

What would be needed for it to run on GPU?

