juliadiff / chainrulescore.jl

AD-backend agnostic system defining custom forward and reverse mode rules. This is the lightweight core that allows you to define rules for your functions in your packages, without depending on any particular AD system.

License: Other

Julia 100.00%
automatic-differentiation julia hacktoberfest

chainrulescore.jl's People

Contributors

ablaom, ararslan, bencottier, bsnelling, chrisrackauckas, devmotion, gdalle, hyrodium, jrevels, jutho, keno, matbesancon, mcabbott, mcognetta, mzgubic, nickrobinson251, niklasschmitz, nsajko, oschulz, oxinabox, pierre-haessig, piever, ranocha, sathvikbhagavan, sethaxen, simeonschaub, st--, touchesir, willtebbutt, yingboma

chainrulescore.jl's Issues

rrule_via_ad // frule_via_ad Calling back into AD from ChainRules

This was originally discussed in JuliaDiff/ChainRules.jl#12 (comment)
and in a few other places.

Often, when defining a chainrule (frule or rrule),
it would be nice to be able to say "give me the chainrule for this part of the function, and if there is not one predefined, use AD to get it"
as part of your rule definition.
Right now this is not possible, except by hard-coding an AD (e.g. Zygote) in.

Whereas if we had a function that an AD system could overload,
then we could do that.

It would also provide a common API for all ADs that support it.

This would help, e.g., with higher-order functions like map, broadcast, etc.
JuliaDiff/ChainRules.jl#122

There is some fiddliness involved in making sure it can be overloaded by multiple ADs, that the user can choose which AD is used, and that it all compiles away, but I think we can sort it all out.
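A minimal sketch of the idea, under assumptions: an overloadable function (here the hypothetical `my_rrule_via_ad`) dispatches on a backend object, so each AD can add its own method and the user picks the AD by passing a backend. The toy `FiniteDiffBackend` just uses central differences; none of these names are an existing package API.

```julia
# Hypothetical names throughout: a sketch of the idea, not an existing API.
abstract type ADBackend end

# Fallback: a backend that does not overload this cannot be called back into.
my_rrule_via_ad(backend::ADBackend, f, args...) =
    error("$(typeof(backend)) does not support calling back into AD")

# Toy backend: central finite differences for scalar functions of one Real.
struct FiniteDiffBackend <: ADBackend
    h::Float64
end

function my_rrule_via_ad(b::FiniteDiffBackend, f, x::Real)
    y = f(x)
    dfdx = (f(x + b.h) - f(x - b.h)) / (2 * b.h)
    pullback(ȳ) = (nothing, ȳ * dfdx)   # (differential for f, differential for x)
    return y, pullback
end

# A rule for a higher-order function that asks whichever backend it is given
# for the rule of the inner function `f`, instead of hard-coding one AD.
function my_rrule_map(backend::ADBackend, f, xs::Vector{<:Real})
    results = [my_rrule_via_ad(backend, f, x) for x in xs]
    ys = first.(results)
    pullbacks = last.(results)
    function map_pullback(ȳs)
        x̄s = [pb(ȳ)[2] for (pb, ȳ) in zip(pullbacks, ȳs)]
        return (nothing, nothing, x̄s)   # (map, f, xs)
    end
    return ys, map_pullback
end
```

Because the backend is an ordinary argument, the choice of AD is resolved by dispatch and can compile away when the backend type is known.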


@jessebett reminded me of this today

Special case derivative of non-holomorphic functions of type ℂ(^n)→ℝ

A common use case for non-holomorphic functions is norms and similar functions that project a complex vector space onto the reals. These are also probably the most interesting for optimization problems. According to the Cauchy-Riemann equations, any such function that is non-constant has to be non-holomorphic, so we currently only have Wirtinger derivatives to describe these correctly.
In these cases, storing the full Wirtinger primal and conjugate is unnecessary, since by looking at the definitions it is easy to see that each must be the complex conjugate of the other, so one would only need to store one of them. When used in arrays, for example, this would halve the memory otherwise required.
My proposal would be to introduce a singleton type, let's call it ConjugateOfWirtingerConjugate for lack of a better name. This would be passed as primal to Wirtinger or WirtingerRule and signify that the Wirtinger primal is just the conjugate of the Wirtinger conjugate. This would need to be special-cased in a few places: for example, if chained with a real derivative, ConjugateOfWirtingerConjugate can be preserved, though otherwise this would need to fall back to a full Wirtinger derivative.
If there's consensus that something like this would be useful, I could prepare a PR. An alternative would be a special AbstractRule for these cases, which might make some things a bit cleaner, but it might be confusing to have two different types for derivatives of non-holomorphic functions. @jrevels seems to be the main authority on everything Wirtinger, what are your thoughts on this?
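A sketch of the proposal, assuming a two-field Wirtinger-like container: `WirtingerSketch` and the local `wirtinger_primal`/`wirtinger_conjugate` below are hypothetical stand-ins, not ChainRulesCore types. Storing the singleton marker as the primal costs no memory, and the primal is recovered on demand as the conjugate of the stored Wirtinger conjugate, which is valid only for real-valued f.

```julia
# Singleton marker, as proposed above: "the primal is conj of the conjugate".
struct ConjugateOfWirtingerConjugate end

# Hypothetical stand-in for a Wirtinger differential.
struct WirtingerSketch{P,C}
    primal::P      # ∂f/∂z, or the marker
    conjugate::C   # ∂f/∂z̄
end

wirtinger_conjugate(w::WirtingerSketch) = w.conjugate
wirtinger_primal(w::WirtingerSketch) = w.primal
# For real-valued f, ∂f/∂z == conj(∂f/∂z̄), so the marker is resolved on demand.
wirtinger_primal(w::WirtingerSketch{ConjugateOfWirtingerConjugate}) = conj(w.conjugate)

# abs2(z) = z * conj(z) is real-valued: ∂/∂z = conj(z) and ∂/∂z̄ = z.
abs2_wirtinger(z::Complex) = WirtingerSketch(ConjugateOfWirtingerConjugate(), z)
```

Since the marker is a zero-size field, the marked type has the size of a single complex number, which is where the memory saving for arrays comes from.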

The Representation Problem: semantic types, vs computational types

Types get used in many ways, especially in Julia.
For this discussion we can break them down into two kinds:
either a semantic use, for example the Int 5 in fill(x, 5),
or a computational-convenience use, e.g. the Int 5 in exp2(5),
where it is just much faster to compute exp2 if your input is an Int.

In the semantic use, that Int has no derivative: its differential is DoesNotExist,
or some other singleton differential.

Whereas in the computational-convenience case, the Int has a derivative: exp2(5)*log(2).
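A quick numeric check of the two readings, as a sketch: treating 5 as a point on the real line, a central difference of `exp2` matches `exp2(5)*log(2)`, while the count in `fill` cannot be perturbed at all (the call with a non-integer length throws).

```julia
# Computational use: treat the Int 5 as a point on the real line.
x = 5
analytic = exp2(x) * log(2)                    # d/dx 2^x = 2^x * ln(2)
h = 1e-6
numeric = (exp2(x + h) - exp2(x - h)) / (2h)   # x + h promotes to Float64

# Semantic use: the 5 in fill(0.0, 5) is a count; there is no fill(0.0, 5 + h)
# to perturb towards, so a derivative with respect to it is meaningless.
semantic_perturbation_fails = try
    fill(0.0, 5 + h)
    false
catch
    true
end
```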

This also occurs for various kinds of (structured) sparsity,
where for computational convenience you might write f(u::AbstractMatrix) = istriu(u) ? f(UpperTriangular(u)) : ...
It's important not to allow perturbations of the structural zeros of sparse types when those zeros have semantic meaning.
But when they don't, you should allow them to be perturbed.
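A small illustration, assuming a function whose `UpperTriangular` branch is purely a fast path (both branches compute the sum of all entries): perturbing the structural zero of a dense input that happens to be upper triangular changes the output, so the derivative there is 1, not 0.

```julia
using LinearAlgebra

# Computational-convenience dispatch: the UpperTriangular branch is only a fast
# path; both branches compute the same function (the sum of all entries).
f(u::AbstractMatrix) = istriu(u) ? sum(UpperTriangular(u)) : sum(u)

u   = [1.0 2.0; 0.0 3.0]     # happens to be upper triangular
E21 = [0.0 0.0; 1.0 0.0]     # perturbs the structural zero at (2, 1)
h = 1e-6
df_du21 = (f(u + h * E21) - f(u)) / h   # forward difference
```

Had the caller passed `UpperTriangular(u)` on purpose (semantic use), the (2, 1) entry would not be a free input and should get no differential.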

idk if this affects us directly.
It might be something to document.

Declaration helpers

We have @scalar_rule which basically follows the API of DiffRules.

I think we should have more.

I think an @rrule that follows the API of ZygoteRules.@adjoint,
and a matching @frule.

Maybe a @DoesNotExist that declares that something is nondifferentiable.

Also something like ZygoteRules.@which (probably @which_frule and @which_rrule),
to locate the rules that would be called.
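A sketch of what a "declare nondifferentiable" helper could expand to, under assumptions: `@nondifferentiable_sketch`, `rrule_sketch` and `DoesNotExistSketch` are hypothetical local names, not ChainRulesCore's. The macro takes a signature like `f(::T1, ::T2)` and generates a rule whose pullback reports DoesNotExist for the function and every argument.

```julia
# Hypothetical stand-ins, not ChainRulesCore's API.
struct DoesNotExistSketch end

# Fallback: no rule is known for this call.
rrule_sketch(f, args...) = nothing

# @nondifferentiable_sketch f(::T1, ::T2, ...) defines rrule_sketch for f at those
# argument types; the pullback reports DoesNotExistSketch for f and every argument.
macro nondifferentiable_sketch(call)
    f = call.args[1]
    argtypes = [a.args[end] for a in call.args[2:end]]  # `::T` parses as Expr(:(::), T)
    argnames = [gensym("x") for _ in argtypes]
    sig = [:($n::$T) for (n, T) in zip(argnames, argtypes)]
    n_out = length(argtypes) + 1
    return esc(quote
        function rrule_sketch(::typeof($f), $(sig...))
            y = $f($(argnames...))
            pullback(_) = ntuple(_ -> DoesNotExistSketch(), $n_out)
            return y, pullback
        end
    end)
end

@nondifferentiable_sketch length(::AbstractVector)
```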

Make `extern` recursive?

It is not too hard to end up in a situation where you have, say,
Thunk(()->Thunk(()->3)), i.e. @thunk(@thunk(3)).

One example (written in #30 style) is if you have

rrule(::typeof(f), x) = f(x), y->(NO_FIELDS, @thunk(3))
function rrule(::typeof(g), x)
    _, inner_pullback = rrule(f, x)
    function g_pullback(y)
        # wraps the inner pullback's (already thunked) partial in another thunk
        return (NO_FIELDS, @thunk(inner_pullback(y)[2]))
    end
    return g(x), g_pullback
end

This is especially common with #30, where we (by necessity) throw a lot more @thunks around.

This is particularly annoying in test right now, and so probably in real use too.


I see a few ways to resolve this:

1. make extern(::Thunk) recursive

this is just changing

@inline extern(x::Thunk) = x.f()

to

@inline extern(x::Thunk) = extern(x.f())

This doesn't help for Casted(Zero()) or Casted(@thunk(3)), etc.,
but Casted could have the same change, I guess,
or it could just go away (#10).

2. make all externing recursive

We would change current definitions of extern into
extern1
then change extern to call extern1 until the result is no longer an AbstractDifferential.
(either via recursion, or via a loop)

3. Make `Thunk(()->Thunk

Way harder than it looks.
Similarly does not apply to things other than thunks
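Options 1 and 2 can be sketched on a hypothetical minimal thunk type (`MyThunk`, not the package's `Thunk`): a recursive unwrap, and a loop-based unwrap that avoids recursion but is not type-stable.

```julia
# Hypothetical minimal stand-in for a thunk type.
struct MyThunk{F}
    f::F
end
(t::MyThunk)() = t.f()

# Option 1: recursive unwrap.
unthunk_recursive(x) = x
unthunk_recursive(t::MyThunk) = unthunk_recursive(t())

# Loop-based variant in the spirit of option 2: keep unwrapping while the
# value is still a thunk. No recursion, but the return type is not inferable.
function unthunk_loop(x)
    while x isa MyThunk
        x = x()
    end
    return x
end
```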


Not sure of the implications of these for inlining;
I know inlining hates recursion.

Here is some test code I prepared earlier.

    @testset "Thunk" begin
        @test extern(Thunk(()->3)) == 3 == extern(@thunk(3))
        @test extern(Thunk(()->Thunk(()->3))) == 3 == extern(@thunk(@thunk(3)))
        @test @thunk(3) isa Thunk
    end

I think the first one is a good place to start,
and I will likely PR that and rebase it into #30.

AbstractArray or Composite?

Background

This issue does stuff with FillArrays, and got me thinking about a problem we've known about for a while but not successfully resolved.

Say that you have a Fill x of length N and you call getindex on it:

x[1]

How should the differential dx w.r.t x associated with this operation be represented? Some options:

  1. (bad) a dense array with a single non-zero element.
  2. (better) a fancy 1-hot array.
  3. (great) a Composite{Fill{eltype(x)}}.

Option 1 is bad because it's obviously O(N) in time and memory.

Option 2 appears to be performant, but only until you consider adding it to another differential. Say that you also got x[2] at some point in your programme. At some point you'll have to accumulate the differential for x[1] and x[2], whose sum is clearly not also a 1-hot array. Consequently, we've lost performance, and essentially regressed to O(N) if you start adding lots of 1-hot arrays together.

Option 3 is great, because when adding two Composite{Fill{eltype(x)}} you get another one in O(1)-time and memory.
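A self-contained sketch of option 3, assuming hypothetical stand-ins `MyFill` (not FillArrays.jl's `Fill`) and `MyFillTangent` (a Composite-like tangent with one entry per differentiable field). Because every element of the array equals `value`, the pullback of `x[i]` only touches that one field, and adding two such tangents is O(1) regardless of `N`.

```julia
# Hypothetical stand-ins, not FillArrays.jl or ChainRulesCore types.
struct MyFill{T} <: AbstractVector{T}
    value::T
    len::Int
end
Base.size(x::MyFill) = (x.len,)
Base.getindex(x::MyFill, i::Int) = (checkbounds(x, i); x.value)

# Composite-style tangent: one entry per differentiable field of MyFill.
# `len` is a size (a semantic Int), so it gets no differential.
struct MyFillTangent{T}
    value::T
end
Base.:+(a::MyFillTangent, b::MyFillTangent) = MyFillTangent(a.value + b.value)  # O(1)

# Pullback of y = x[i] w.r.t. the fields of x: every element equals x.value,
# so ∂y/∂value = 1 whatever i is.
getindex_pullback(x::MyFill, i::Int, ȳ) = MyFillTangent(ȳ)
```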

Upshot

As presented above, what needs to happen for Fills is quite obvious imho. What is less obvious is how this generalises. There's been a discussion going on for a while about the correct way to represent the differential of an AbstractArray.

It seems that the correct thing to do with an Array is represent it as another Array, and this seems to be because an Array is already completely general in that its elements can take any value you like.

What is less clear is what to do for structured arrays. Take, for example, Diagonal. Should we allow its adjoint to be represented by any other AbstractArray, or should we be using a Composite? My feeling is that this observation regarding Fills provides evidence in favour of the latter, and that we need to figure out ways to handle resulting problems that arise (e.g. say that a Composite representing a Diagonal was propagated into the rrule for matrix-matrix multiply. We would need to know what to do with it)

Jacobians

Sometimes you actually do want the jacobian of a function. If this function is a vector->vector function, the Jacobian is a matrix. It's not immediately obvious how to define the jacobian for, say, a function that accepts a struct and returns a struct. It would be nice to figure out some sensible conventions here.
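For the vector-to-vector case there is a well-defined answer that can be built from a pullback alone: pulling back the i-th basis vector of the output gives row i of the Jacobian. A sketch, with a hand-written rule `rrule_f` for a hypothetical example function:

```julia
# Hand-written rule for f(x) = [x1*x2, sin(x1)] in the (y, pullback) style.
function rrule_f(x::Vector{Float64})
    y = [x[1] * x[2], sin(x[1])]
    pullback(ȳ) = [ȳ[1] * x[2] + ȳ[2] * cos(x[1]), ȳ[1] * x[1]]
    return y, pullback
end

# Build the Jacobian row by row: pulling back the i-th basis vector gives row i.
function jacobian_via_pullback(rrule_fn, x::Vector{Float64})
    y, pb = rrule_fn(x)
    J = zeros(length(y), length(x))
    for i in eachindex(y)
        e = zeros(length(y))
        e[i] = 1.0
        J[i, :] = pb(e)
    end
    return J
end
```

The open question in this issue is what plays the role of "basis vector" and "matrix" when inputs and outputs are structs rather than vectors.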

accumulate!(::Real, ...)

What is the intended behaviour when the first argument to accumulate! is something for which materialize! isn't defined? For example, if we are working with a Real number.

docs for basic definitions

There should be a section in the docs briefly summarising what each of the following is:

  • rrule
  • frule
  • Rule
  • Differential

If I have it right:

rrule/frule take a function and a point,
and return the value of that function at that point
and a tuple of Rules.
Rules take a sensitivity of the outputs / inputs (respectively?)
and return a Differential, which is like a value (and may be a value),
representing that sensitivity.
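A simplified sketch of those shapes for `sin`, under assumptions: `frule_sketch` and `rrule_sketch` are hypothetical local functions, and each returns a single propagator closure rather than a tuple of `Rule` objects. The forward propagator maps an input sensitivity to an output sensitivity; the reverse one maps an output sensitivity to one differential per input (the first for the function itself).

```julia
# Illustrative only: plain functions, not the ChainRulesCore API.
function frule_sketch(::typeof(sin), x::Real)
    y = sin(x)
    pushforward(ẋ) = cos(x) * ẋ          # input sensitivity ↦ output sensitivity
    return y, pushforward
end

function rrule_sketch(::typeof(sin), x::Real)
    y = sin(x)
    pullback(ȳ) = (nothing, ȳ * cos(x))  # output sensitivity ↦ (∂sin, ∂x)
    return y, pullback
end
```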

Make it easier to track down errors that occur in a thunk

Consider:

julia> extern(foo())
ERROR: ArgumentError: invalid rational: zero(Int64)//zero(Int64)
Stacktrace:
 [1] __throw_rational_argerror(::Type) at ./rational.jl:19
 [2] Rational at ./rational.jl:14 [inlined]
 [3] Rational at ./rational.jl:21 [inlined]
 [4] //(::Int64, ::Int64) at ./rational.jl:44
 [5] (::var"#13#14")() at /Users/oxinabox/.julia/packages/ChainRulesCore/LC8Sv/src/differentials.jl:243
 [6] (::Thunk{var"#13#14"})() at /Users/oxinabox/.julia/packages/ChainRulesCore/LC8Sv/src/differentials.jl:250
 [7] extern(::Thunk{var"#13#14"}) at /Users/oxinabox/.julia/packages/ChainRulesCore/LC8Sv/src/differentials.jl:251
 [8] top-level scope at REPL[9]:1

An error was thrown within the thunk,
but the stacktrace tells me nothing about how to find the thunk.

It doesn't tell me where the Thunk was, nor what it contained.

There are a few things we can investigate:

  1. (somehow) make it so the line numbers for the thunk's inner function in [5] reflect where the @thunk was rather than its contents
  2. use __source__.line and __source__.file to actually include that in the function's name
  3. include the wrapped expression in the function's name
  4. Have a way to disable thunking and make it eager. I think this can be done by having disable_thunking()=@eval macro thunk(x) esc(x) end and a matching enable_thunk() to restore the original definition
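A sketch combining ideas 2 and 3, assuming hypothetical names (`LocatedThunk`, `ThunkError`, `@located_thunk`): the macro records its call site from `__source__` and the wrapped expression as a string, and any error raised when the thunk is evaluated is rethrown wrapped with that information.

```julia
# Hypothetical thunk type that remembers where it was created and what it wraps.
struct LocatedThunk{F}
    f::F
    location::String
    expr::String
end

# Error wrapper carrying the thunk's origin.
struct ThunkError <: Exception
    location::String
    expr::String
    inner::Any
end
Base.showerror(io::IO, e::ThunkError) =
    print(io, "error in thunk created at ", e.location, " (", e.expr, "): ", e.inner)

function (t::LocatedThunk)()
    try
        return t.f()
    catch err
        throw(ThunkError(t.location, t.expr, err))
    end
end

# Macros receive the call site as `__source__::LineNumberNode` (fields `file`, `line`).
macro located_thunk(body)
    loc = string(__source__.file, ":", __source__.line)
    src = string(body)
    return :(LocatedThunk(() -> $(esc(body)), $loc, $src))
end
```

One trade-off of this design: throwing a new exception replaces the inner backtrace, so a real version might also store `catch_backtrace()`.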

The Extensibility Problem, for propagator closures

Here's a thing that isn't currently possible and is, I believe, something that we might actually want to care about. Consider the pullback for AbstractMatrix multiplication:

function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    function times_pullback(Ȳ)
        return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ))
    end
    return A * B, times_pullback
end

Provided that Ȳ is itself an AbstractMatrix, for which * with other matrices will be correctly defined assuming getindex is correctly defined, something correct will happen even if it's slow.

Now consider the case that Ȳ is a NamedTuple, possibly because Y is some non-Matrix AbstractMatrix. Now what happens? The above breaks: * isn't defined for NamedTuples, nor is it possible to extend times_pullback to handle Ȳ from outside the original rrule definition. One's only recourse is to add a completely new definition of pullback for * with the AbstractTypeofA and AbstractTypeofB one expects to see Ȳ with, which itself presupposes that a method of pullback(typeof(*), ::AbstractTypeofA, ::AbstractTypeofB) doesn't already exist; if it does, no option is available but to modify the existing method.

Phrased differently, the current design requires each rrule to be implemented to handle every possible type of Ȳ that it might ever see. This is clearly an unreasonable requirement because Julia permits the creation of new types, and the input types to pullback do not uniquely specify the type of Ȳ. Assuming this unreasonable requirement is unmet, we are left with two options when the above is encountered:

  1. Modify the original rrule to handle the new type encountered (the bad-ness of this is assumed self-evident)
  2. Write a new rrule specialised to different types of A and B. This really isn't great from a code re-use perspective. * is not a pathological case because the forwards-pass is quite straightforward, but other cases are worse.

This problem appears to manifest itself in cases where the forwards-pass is perfectly good for multiple types, but the reverse-pass requires care. For example Diagonal * Matrix: the forwards-pass and data required on the reverse-pass is no different than Matrix * Matrix, but the reverse-pass implementation is necessarily quite different.

This lack of extensibility is a direct consequence of the (value, back) = pullback(...) design choice that ChainRules / Zygote make. Nabla made a slightly different design choice in which the forwards- and reverse- bits of a pullback were separate functions, so you could extend things. This design doesn't share the extensibility issue that the ChainRules / Zygote style presents, but equally doesn't immediately enable the same sharing of state on the forwards- and reverse-passes.

One possible resolution would be to adopt the separated forwards- and reverse- passes chosen in Nabla, and allow an arbitrary communication object to be shared between the forwards- and reverse- passes.

function forward(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    # primal output, plus the data needed on the reverse-pass; storing `*` itself
    # makes the NamedTuple's type record which function produced it
    return A * B, (f=(*), A=A, B=B)
end

function pullback(data::NamedTuple{(:f, :A, :B), <:Tuple{typeof(*), AbstractMatrix, AbstractMatrix}}, Ȳ::AbstractMatrix)
    return (NO_FIELDS, Ȳ * data.B', data.A' * Ȳ)
end
function pullback(data::NamedTuple{(:f, :A, :B), <:Tuple{typeof(*), AbstractMatrix, AbstractMatrix}}, Ȳ::NamedTuple)
    # do other stuff: this method can be added later, from any package
end

A forward call would therefore not return a closure, but rather whatever intermediate data is deemed by the implementer to be important for the reverse-pass. pullback is then called with the appropriate signature to evaluate the adjoint. This interface is somewhat more verbose than the closure-based interface due to the need to copy-paste the signature all over the place, although this may be alleviated with some careful metaprogramming tooling. I would anticipate that we would also see an improvement in stack-trace readability, since we would get direct calls to a pullback function, with the types of forwards arguments placed prominently.

In summary, this change buys us the ability to extend reverse-pass behaviour using multiple dispatch, and hence code-reuse, at the expense of increased verbosity and the need to explicitly specify the data from the forwards-pass that may be required on the reverse.
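A runnable sketch of the extensibility this buys, using hypothetical names (`forward_sketch`, `pullback_sketch`) rather than any package API: the generic reverse pass is written once, and a more specific method for a `Diagonal` first argument is added afterwards without touching the original, reusing the same forwards-pass.

```julia
using LinearAlgebra

# Hypothetical separated forwards/reverse functions in the Nabla style.
forward_sketch(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) = (A * B, (f=(*), A=A, B=B))

# Generic reverse pass for matrix multiplication: returns (Ā, B̄).
pullback_sketch(d::NamedTuple{(:f, :A, :B)}, Ȳ::AbstractMatrix) = (Ȳ * d.B', d.A' * Ȳ)

# Added later, possibly in another package, without touching the methods above:
# same forwards-pass, but Ā is projected onto the diagonal when A is Diagonal.
function pullback_sketch(
    d::NamedTuple{(:f, :A, :B), <:Tuple{typeof(*), Diagonal, AbstractMatrix}},
    Ȳ::AbstractMatrix,
)
    return (Diagonal(Ȳ * d.B'), d.A' * Ȳ)
end
```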

This isn't something that needs resolving immediately, but I feel it should be given some consideration so that we can at least be aware that this is an issue we're choosing to ignore if nothing is done about it.

Side note: this appears to be analogous to the expression problem, which you can consult Stefan's JuliaCon talk on. Specifically, that we can't define new methods of back for existing pullbacks is (I think) analogous to not being able to define new methods that extend the functionality of existing types.

extern(::Casted) seems ill-defined. (Remove Casted?)

extern(cast(1)) returns 1. This is unfortunate behaviour as we lose all of the information (e.g. size) that (arguably) should be associated with the casted thing. For example the rrule for sum currently returns a casted 1 to represent an array of ones, which would have been the same size as the input. Consequently, without reference to the input you don't know what the appropriate size is for a particular Casted object.

Is there a particular reason that this route has been taken as opposed to, say, returning AbstractFills (from FillArrays.jl), which would be just as efficient in all cases that I can think of?
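A sketch of the alternative for `sum`, using a hypothetical minimal lazy constant array `ConstArray` as a stand-in for FillArrays.jl: the pullback stays O(1) in memory like a casted scalar, but keeps the input's size.

```julia
# Hypothetical minimal lazy constant array (stand-in for FillArrays.jl's Fill).
struct ConstArray{T,N} <: AbstractArray{T,N}
    value::T
    dims::NTuple{N,Int}
end
Base.size(a::ConstArray) = a.dims
Base.getindex(a::ConstArray{T,N}, I::Vararg{Int,N}) where {T,N} = (checkbounds(a, I...); a.value)

# Pullback of sum(x): every element contributes with weight 1, so x̄ = ȳ everywhere.
# Unlike a bare casted scalar, the result carries size(x), still in O(1) memory.
sum_pullback(x::AbstractArray, ȳ) = ConstArray(ȳ, size(x))
```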

Some thoughts on Wirtinger

This comes up in discussion on Zygote
FluxML/Zygote.jl#142
FluxML/Zygote.jl#291

Mike/Zygote doesn't like Wirtinger,
which I personally can get behind, since it is well outside the math I use.
And anything to do with them is a maintenance issue because,
I think, of the people that contribute to this repo only @simeonschaub and @jrevels properly get them
(both in terms of the how, and the why).

Zygote never wants to see a Wirtinger differential result.
I think, though, that this might also mean that Zygote wants to treat nonholomorphic functions as nondifferentiable?

Would it make sense to define

function extern(w::Wirtinger)
    iszero(extern(wirtinger_conjugate(w))) || throw(DomainError(w, "Not holomorphic"))
    return extern(wirtinger_primal(w))
end

Or the equiv inside Zygote.

Could we have a separate rrule/frule
for Wirtinger that is allowed to return Wirtinger,
while normal uses of rrule and frule don't?
Then general_rrule and general_frule would fall back to rrule and frule if there is no special overload for the Wirtinger case defined.
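A sketch of that fallback arrangement, using hypothetical names (`rrule_sketch`, `general_rrule_sketch`, `WirtingerSketch`): ordinary callers only ever see plain differentials, while callers of the general entry point also get Wirtinger-specific rules and otherwise fall back to the ordinary ones.

```julia
# Hypothetical stand-in for a Wirtinger differential.
struct WirtingerSketch{P,C}
    primal::P      # ∂f/∂z
    conjugate::C   # ∂f/∂z̄
end

# Ordinary rules: never return a Wirtinger differential. Fallback: no rule.
rrule_sketch(f, args...) = nothing

# Wirtinger-aware entry point: falls back to the ordinary rule when no
# Wirtinger-specific method is defined.
general_rrule_sketch(f, args...) = rrule_sketch(f, args...)

function rrule_sketch(::typeof(abs2), x::Real)
    pullback(ȳ) = (nothing, 2x * ȳ)
    return abs2(x), pullback
end

# Only callers of general_rrule_sketch ever see this rule's Wirtinger result.
function general_rrule_sketch(::typeof(abs2), z::Complex)
    pullback(ȳ) = (nothing, WirtingerSketch(conj(z) * ȳ, z * ȳ))
    return abs2(z), pullback
end
```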

*, One, extern, and accumulate + friends

I almost named this Getting rid of * and One, in line with most of our other open issues but, alas, that wasn't an appropriate title. This issue is more of a commentary on the value of properly deciding what operations make sense for Differentials all of the time, what operations only make sense some of the time, and what operations never make sense.

Operations on Differentials

It is currently the case that * is defined between any two Differentials. It's not actually clear that this should always be the case. Quite possibly we should consider only defining * in certain cases e.g. for scalars. Similarly, the role of One is unclear in general, and there's some confusion between + and accumulate. The aim of this issue is to resolve these issues, and improve our collective understanding of what the things in the package actually (should) do.

First consider the things that you need to be able to do with Differentials to be able to perform AD:

  1. Add Differentials during the accumulation phase of reverse-mode. So we definitely need addition (accumulation in Zygote-speak) to work, always.
  2. Apply linear maps to differentials. AD essentially takes in your original programme and returns a linear programme that accepts Differentials and returns some other Differentials.

These are the only two things necessary for ChainRules(Core) to provide for AD systems. With this in mind, we plough ahead.

+(differential, differential)

This should always be defined, and it's clear how to do it with all of our currently implemented Differential types, with the exception of One. More on that later.

Not necessary to define +(primal, differential) all of the time

Notably, it's not necessary to be able to add Differentials to their primal type. Although it's often possible to do this (and you need to be able to do it for objects that you want to gradient-descend on, for example), automatic differentiation does not require that you can.

For example, the differential of a Vector{Float64} will typically be represented as a Vector{Float64}, which we can clearly add. It could also be a Zero, or a Thunk. Now, it happens to be the case that we know how to add these to their primals, but what if we have the differential w.r.t. the following struct:

struct Foo
    x
    y
    Foo(x) = new(x, x)
end

The differential will be represented as a Composite (see @oxinabox's PR) with two fields, x and y. It's entirely unclear how to usefully add these two objects. Clearly the result needs to be of type Foo, but it's not possible to increment the two fields of Foo independently, since its only constructor sets y = x.

To summarise,

  1. addition between a given differential and some primal type needn't be defined all of the time
  2. addition will be defined between differentials and primals most of the time, e.g. you should expect that you can add the differential of a Vector or Float64 to said Vector or Float64. Similarly, addition between a primal and Zero is always defined (just don't modify the primal). But there exist some primal-differential pairs that don't admit a useful / sensible definition of +, and that is okay.
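A sketch of that split for the `Foo` example, assuming a hypothetical Composite-like `FooDifferential` with one entry per field: differential + differential is defined field by field, while Foo + differential is deliberately left undefined and raises a MethodError.

```julia
struct Foo
    x
    y
    Foo(x) = new(x, x)
end

# Hypothetical Composite-style differential for Foo: one entry per field.
struct FooDifferential
    x::Float64
    y::Float64
end

# differential + differential: always well-defined, field by field.
Base.:+(a::FooDifferential, b::FooDifferential) = FooDifferential(a.x + b.x, a.y + b.y)

# primal + differential is intentionally not defined: the only constructor,
# Foo(x), forces x == y, so there is no Foo to return for unequal increments.
```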

One

See here

In short, it doesn't make sense as a differential, but it might not be a totally useless construct.

Multiplication

Multiplication doesn't make sense between all differential types, but we are currently required to define it for all differential types. For example, what does it mean to multiply two composites? You could define it to be some notion of elementwise multiplication, but it's not clear what we would gain by doing that.

In the context of ChainRules(Core), multiplication is best thought of as a data-parametrised linear map that is well-defined for certain types of differentials (scalars, vectors, matrices, etc) but not others (e.g. Composites). This observation resolves the issue with Wirtingers whereby it's not at all clear how to multiply them together. We could remove the method with the massive error message entirely since there would no longer be any expectation on the part of the user that one should be able to multiply them together in an unambiguous manner.

As such, it clearly has value and I believe that we should

  • define it for Zero, because anything * Zero is Zero (currently implemented)
  • define it recursively for AbstractThunk (currently implemented)
  • not require it to be defined for Composite, because how would it be usefully defined?
  • remove the definition of * for Wirtingers entirely, as there's no longer any expectation that it should work.
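A sketch of those multiplication rules on hypothetical stand-ins (`ZeroSketch`, `ThunkSketch`, `CompositeSketch`, not ChainRulesCore's types): anything times zero is zero, thunks are unwrapped before multiplying, and composites simply have no `*` method. The extra two-argument methods exist only to resolve dispatch ambiguities between the zero and thunk rules.

```julia
struct ZeroSketch end
struct ThunkSketch{F}
    f::F
end
unthunk_sketch(t::ThunkSketch) = t.f()
struct CompositeSketch{NT<:NamedTuple}
    fields::NT
end

# anything * Zero is Zero
Base.:*(::ZeroSketch, ::Any) = ZeroSketch()
Base.:*(::Any, ::ZeroSketch) = ZeroSketch()
Base.:*(::ZeroSketch, ::ZeroSketch) = ZeroSketch()

# thunks are multiplied recursively, by unwrapping them first
Base.:*(a::ThunkSketch, b) = unthunk_sketch(a) * b
Base.:*(a, b::ThunkSketch) = a * unthunk_sketch(b)
Base.:*(a::ThunkSketch, b::ThunkSketch) = unthunk_sketch(a) * unthunk_sketch(b)

# resolve Zero-vs-Thunk ambiguities in favour of Zero
Base.:*(::ZeroSketch, ::ThunkSketch) = ZeroSketch()
Base.:*(::ThunkSketch, ::ZeroSketch) = ZeroSketch()

# No * is defined for CompositeSketch: there is no obvious useful meaning.
```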

Accumulate + friends

Given the centrality of +, it makes sense to re-visit accumulate at this point.

  • accumulate literally just implements +. We have +, so there's no need for accumulate.
  • accumulate! is the in-place version of +. This method lets you do A = A + B without allocating, where A and B are differentials. It might make sense for us to rename this add! for consistency with +. It's important to note, however, that it's not always possible to in-place add differentials, and this will only be well-defined some of the time. A more useful way of thinking about this functionality is as maybe_add_inplace.
  • Similar comment for store!. It only really makes sense some of the time, specifically when dealing with dense arrays, so it's not clear to me how much of a help this really is. Probably requires further thought.

To my mind, the conclusion here is that we just remove accumulate, and think further about what we actually want to get out of accumulate! and store! in a separate issue. They're not top-priority, but there are definitely use-cases for them in big neural-network-y applications where you typically handle a lot of StridedArray types.
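A sketch of the "maybe_add_inplace" idea, with that name used as a hypothetical function: mutate the accumulator when it is a mutable array, otherwise (e.g. a Real, as in the accumulate!(::Real, ...) question) return a fresh value. Callers must always use the return value, since they cannot know which path was taken.

```julia
# Immutable fallback, e.g. for Real: no mutation possible, return a new value.
maybe_add_inplace(a, b) = a + b

# Mutable dense arrays: accumulate into `a` without allocating a new array.
function maybe_add_inplace(a::Array, b::AbstractArray)
    a .+= b
    return a
end
```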

Deploy docs

I think this just involves adding a deploy key?

There are docstrings that should be more easily accessible. In the future we may want to do something else (#12) but this seems like the best start

Can we get rid of `Thunk`?

@jrevels: @willtebbutt and I were going through the Differentials to make sure we actually know what they are for,
and started to wonder if we need them all.

Each one we get rid of simplifies things a lot,
especially when it comes to #16

I think we might be able to just have
Wirtinger, One, Zero, and DNE.

Wirtinger

I have only the barest understanding of what this is.
It effectively seems like a particularly convenient way to deal with
derivatives with respect to complex numbers (in contrast to handling them as structs, see #4).
Probably useful.

DNE

Does not exist. Obviously useful.

One, Zero

Useful identities that are evaluated lazily, and can thus be removed from the chain efficiently.

Casted

It is kind of the generalization of One, and Zero.
(in that One() could also be written Casted(true) etc).
It lets us lazily delay computing a broadcast,
so that it can be fused later.
But I think in the short term we can simplify the code
by replacing say
Rule((Δx, Δy) -> sum(Δx * cast(y)) + sum(cast(x) * Δy))
with
Rule((Δx, Δy) -> sum(Δx .* y) + sum(x .* Δy))
(from here),
which for that particular case would even be identical in performance, I think,
since it does not end up returning any kind of lazy computation.
And later we can try getting back the lazy computation and broadcast fusing by returning broadcasted.
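A sketch of the eager and lazy versions of that particular expression, using Base's `Base.Broadcast.broadcasted` and `materialize`: the eager form allocates one temporary per `.*`, while the lazy form builds an unmaterialized `Broadcasted` tree that is fused into a single allocation when materialized.

```julia
using Base.Broadcast: broadcasted, materialize

x, y   = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
Δx, Δy = [0.1, 0.0, 0.0], [0.0, 0.0, 0.2]

# Eager version, as in the replacement Rule above: two temporary arrays.
eager = sum(Δx .* y) + sum(x .* Δy)

# Lazy version: nothing is computed until materialize, which fuses both
# elementwise products and the addition into one loop and one allocation.
lazy = broadcasted(+, broadcasted(*, Δx, y), broadcasted(*, x, Δy))
fused = sum(materialize(lazy))
```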

Getting rid of Casted would solve #10

Thunk

Thunk seemed really useful at first,
but I am not sure anymore that it actually does anything.

A thunk basically wraps a call f(v) that returns a differential in ()->f(v),
so as not to have to compute it yet.
But any time you interact with it (e.g. via add or mul) it gets externed,
because if you don't do that you can get huge chains of thunks that call thunks,
and also because at the time you call e.g. add you probably do actually want the value: you're not going to skip it and only use the other part.

And using it inside a rule isn't actually deferring anything extra until the backwards pass, since rules themselves are deferred until the backwards pass.

E.g., looking at this rule,
rather than

function rrule(::typeof(inv), x::AbstractArray)
    Ω = inv(x)
    m = @thunk(-Ω')
    return Ω, Rule(ΔΩ -> m * ΔΩ * Ω')
end

we could just do

function rrule(::typeof(inv), x::AbstractArray)
    Ω = inv(x)
    return Ω, Rule(ΔΩ -> -Ω' * ΔΩ * Ω')
end

Which boils down to the same thing, since when the rule is invoked the thunk gets externed anyway, by * becoming mul.
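A numeric sanity check of the eager pullback for inv, as a sketch with arbitrary example matrices: for the reverse rule X̄ = -Ω' ΔΩ Ω', the adjoint identity ⟨X̄, dX⟩ = ⟨ΔΩ, dΩ⟩ should hold, where dΩ is the directional derivative of inv along dX (approximated here by a central difference).

```julia
using LinearAlgebra

# Pullback of inv written eagerly, as in the second rule above (no thunk).
inv_pullback(Ω, ΔΩ) = -Ω' * ΔΩ * Ω'

X  = [2.0 1.0; 0.5 3.0]
dX = [0.3 -0.2; 0.1 0.4]      # an arbitrary input perturbation
ΔΩ = [1.0 0.0; 2.0 -1.0]      # an arbitrary output sensitivity
Ω  = inv(X)

h = 1e-6
dΩ = (inv(X + h * dX) - inv(X - h * dX)) / (2h)   # central difference

# Adjoint identity: <X̄, dX> should equal <ΔΩ, dΩ> (Frobenius inner products).
lhs = dot(inv_pullback(Ω, ΔΩ), dX)
rhs = dot(ΔΩ, dΩ)
```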

Even in the case of the derivative for multiple things, so you would have multiple rules referencing the thunk, it still doesn't change anything since thunks don't cache
(#7).
I recall @jrevels saying that they used to cache, so maybe still having them is a legacy of that time and we just didn't notice that they didn't do anything anymore.

What type should NO_FIELDS be?

NO_FIELDS should be Zero() not DNE()

Thought about it.
The derivative with respect to all the fields, of which there are none,
is probably better described as Zero.
There is no relation between those (nonexistent) fields and the function output.

And we don't want it to error when externed.
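A sketch of why this matters, using hypothetical stand-ins (`ZeroStandIn`, `DNEStandIn`, `extern_sketch`): if the differential for the function itself is a zero, every partial of a pullback can be externed uniformly; a does-not-exist marker would make that throw. Externing the zero to the integer 0 is an assumption of this sketch.

```julia
# Hypothetical stand-ins for Zero and DoesNotExist, for illustration only.
struct ZeroStandIn end
struct DNEStandIn end

# Externing a zero is harmless; externing a DNE is an error.
extern_sketch(x) = x
extern_sketch(::ZeroStandIn) = 0
extern_sketch(::DNEStandIn) = throw(ArgumentError("derivative does not exist"))

# Pullback of square(x) = x^2: the function has no fields, so its own
# differential is ZeroStandIn(), and callers can extern all partials alike.
square_pullback(ȳ, x) = (ZeroStandIn(), 2x * ȳ)
```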

Support chunked frule

From #90 it seems that @YingboMa wants frule to be able to be called on a Vector of sensitivities for the same primal value,
and get a vector of sensitivities back,
but without broadcasting? (presumably because that would also recompute the forward primal)

I don't understand it properly.
So this thread is to get @YingboMa, @shashi or @ChrisRackauckas to explain that.

This might need a redesign of frule again, similar to solving #74.
Maybe we went too far there, since broadcasting the pushforward would presumably solve that case.
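One reading of "chunked", as a sketch under assumptions (`frule_sin_sketch` is a hypothetical rule for `x -> sin.(x)`, not an existing API): compute the primal and its shared intermediates once, and let the pushforward accept either a single sensitivity vector or a matrix whose columns are a chunk of sensitivities. A counter confirms the primal is evaluated only once.

```julia
const primal_evaluations = Ref(0)   # counts how often the primal is computed

function frule_sin_sketch(x::Vector{Float64})
    primal_evaluations[] += 1
    y = sin.(x)     # primal, computed once per call
    c = cos.(x)     # intermediate shared by every sensitivity
    # Broadcasting c (a column) against Ẋ handles one direction (a vector)
    # or a chunk of directions (one per column of a matrix) alike.
    pushforward(Ẋ) = c .* Ẋ
    return y, pushforward
end
```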

Overload generation mode: support enumerating the rules

Right now, we don't provide any way to list out all rules.
A bunch of packages that use DiffRules assume that that is possible.
And then they take that list and use metaprogramming to overload some methods.

I am thinking we should maybe provide a legacy support function that does this,
which uses methods(frule) and then filters it for things that promise not to return nothing.

And we should call it something that makes it clear that using it is not advised,
since it prevents extensibility in other packages
(but we can give advice, e.g. to ensure the latest set of rules can be loaded via calling update_rules or something).
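A sketch of the enumeration step on a hypothetical local rule function `rule_sketch`, assuming rules dispatch on `::typeof(f)` as their first argument. It reads each method's signature via `Method.sig`, `Base.unwrap_unionall` and the singleton type's `instance` field, which are Julia internals rather than public API, so treat it as illustrative only.

```julia
# Hypothetical local rule table standing in for `frule`.
rule_sketch(::typeof(sin), x::Real) = (sin(x), cos(x))
rule_sketch(::typeof(exp), x::Real) = (exp(x), exp(x))

# List the functions that have a rule, by inspecting each method's signature,
# which looks like Tuple{typeof(rule_sketch), typeof(f), argtypes...}.
function functions_with_rules(rulefn)
    fs = Any[]
    for m in methods(rulefn)
        sig = Base.unwrap_unionall(m.sig)
        length(sig.parameters) >= 2 || continue
        ftype = sig.parameters[2]
        # singleton function types expose the function object as `.instance`
        if ftype isa DataType && isdefined(ftype, :instance)
            push!(fs, ftype.instance)
        end
    end
    return fs
end
```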

Write docs on how to write rules to run different computations at different times

With #30
this might be a great example

function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
    # Compute C (the return value) differently from just calling `A/B`,
    # so that we have some parts precomputed
    # that we need to use for computing the pullback

    Aᵀ, dA_pb = rrule(adjoint, A)
    Bᵀ, dB_pb = rrule(adjoint, B)
    Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ)
    C, dC_pb = rrule(adjoint, Cᵀ)

    function slash_pullback(Ȳ)
        # These two run no matter which partial you want,
        # but they don't run if you only want the result (C)

        _, dC = dC_pb(Ȳ)
        # partials come back in the argument order of `\(Bᵀ, Aᵀ)`
        _, dBᵀ, dAᵀ = dS_pb(dC)

        # Each of these only runs if you want that particular partial
        ∂A = @thunk last(dA_pb(dAᵀ))
        ∂B = @thunk last(dB_pb(dBᵀ))

        (NO_FIELDS, ∂A, ∂B)
    end
    return C, slash_pullback
end
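The "different computations at different times" pattern can be checked in isolation with a sketch: `LazyPartial` is a hypothetical minimal thunk, and `work_done` records which partials actually ran. Shared work happens when the pullback is called; each partial's own work happens only when that partial is demanded.

```julia
# Hypothetical minimal thunk: defers a computation until it is called.
struct LazyPartial{F}
    f::F
end
(p::LazyPartial)() = p.f()

const work_done = String[]   # records which partials were actually computed

function rrule_mul_sketch(a::Float64, b::Float64)
    c = a * b
    function mul_pullback(c̄)
        shared = c̄   # stands in for work every partial needs (like dC_pb/dS_pb above)
        ∂a = LazyPartial(() -> (push!(work_done, "a"); shared * b))
        ∂b = LazyPartial(() -> (push!(work_done, "b"); shared * a))
        return (nothing, ∂a, ∂b)
    end
    return c, mul_pullback
end
```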

Invocation helpers

In the same spirit as #44, we might want some invocation helpers.

Something like

function grad(f, args...)
    _, pullback = rrule(f, args...)
    partials = pullback(One())
    return extern.(partials[2:end])
end

Which is what I was using for testing stuff with Zygote.
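A self-contained version of the same helper, as a sketch: `rrule_sketch` is a hypothetical hand-written rule for `*`, and the pullback is seeded with `one(y)` as the scalar stand-in for `One()`. The partial for the function itself is dropped, as above.

```julia
# Hypothetical hand-written rule for x * y.
function rrule_sketch(::typeof(*), x::Real, y::Real)
    pullback(ȳ) = (nothing, ȳ * y, ȳ * x)   # (∂self, ∂x, ∂y)
    return x * y, pullback
end

# Gradient helper in the spirit of `grad` above.
function grad_sketch(rrule_fn, f, args...)
    y, pullback = rrule_fn(f, args...)
    partials = pullback(one(y))   # seed with 1, the scalar stand-in for One()
    return partials[2:end]        # drop the partial for f itself
end
```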

Rules for functions / structs etc

The current design requires that rules have the form

function rrule(::typeof(foo), args...)
    ...
    return (rule_arg_1, rule_arg_2, ...)
end

I would like to propose to extend this to the following:

function rrule(foo::typeof(foo), args...)
    ...
    return (rule_foo, rule_arg_1, rule_arg_2, ...)
end

A lot of the time rule_foo will be a DNE, which is fine. However, if we're looking to implement rules for particular structs, then it may well be important to be able to represent the gradient w.r.t. the fields of the struct.
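A sketch of the case where the first entry matters: a callable struct whose field is a parameter. `Scale` and `rrule_sketch` are hypothetical; the differential for the callable itself is represented as a NamedTuple over its fields.

```julia
# A callable struct whose field is a parameter.
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a * x

# Differential for the callable (a NamedTuple over its fields) comes first,
# followed by the differential for the argument x.
function rrule_sketch(s::Scale, x::Real)
    pullback(ȳ) = ((a = ȳ * x,), ȳ * s.a)
    return s(x), pullback
end
```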

Dealing with non-array/scalar structured values

Consider the case of an eigenvalue decomposition from the eigen function, which produces both eigenvalues and eigenvectors in an Eigen object. Giles provides forward- and reverse-mode sensitivities for the decomposition, which depend on both the eigenvalues and vectors. That raises the question of how this should be expressed in ChainRules.

In a conversation with @jrevels, he said that his vision for this was to use named tuples in cases such as this, which involve structures other than arrays and scalars. However, we weren't able to come to a concrete conclusion on whether it makes more sense for a Rule for e.g. eigen to produce a named tuple upon application of the rule, or if frule/rrule should themselves produce a named tuple of Rules.

It seems that the author of a rule can reuse computations for the eigenvalues and vectors if applying the rule yields a single named tuple. This would look something along the lines of (untested):

function frule(::typeof(eigen), X::AbstractMatrix)
    E = eigen(X)
    λ, U = E
    n = size(X, 1)
    R = Rule() do ΔX
        Y = U \ (ΔX * U)   # U⁻¹ ΔX U; equals U' * ΔX * U only for symmetric X
        values = copy(diag(Y))
        @inbounds for j = 1:n, i = 1:n
            if i == j
                Y[i,j] = 0
            else
                Y[i,j] /= λ[j] - λ[i]
            end
        end
        vectors = U * Y
        (values=values, vectors=vectors)
    end
    return E, R
end
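Here is a self-contained version of the same pushforward without the `Rule` wrapper, so the formula can be checked numerically. It is a sketch that assumes a real symmetric `X` with distinct eigenvalues, and `eigen_pushforward` is a hypothetical name. Both outputs come from a single `Y = U' * ΔX * U`. That shared work is what the single-named-tuple design lets a rule author reuse.

```julia
using LinearAlgebra

# Sketch: forward-mode sensitivities of a symmetric eigendecomposition,
# returned together as one NamedTuple so both outputs share Y.
# Assumes X is real symmetric with distinct eigenvalues (U' == inv(U)).
function eigen_pushforward(X::Symmetric, ΔX::AbstractMatrix)
    E = eigen(X)
    λ, U = E.values, E.vectors
    n = size(X, 1)
    Y = U' * ΔX * U
    values = diag(Y)                     # dλ_i = u_i' * ΔX * u_i
    F = zeros(n, n)
    for j in 1:n, i in 1:n
        i == j && continue
        F[i, j] = Y[i, j] / (λ[j] - λ[i])    # zero on the diagonal
    end
    vectors = U * F                      # dU = U * F
    return (values = values, vectors = vectors)
end
```

A finite-difference comparison of `values` against `eigvals` is a quick sanity check. Comparing eigenvectors directly is less reliable, because of their sign ambiguity. Instead, one can check that `U' * vectors` is antisymmetric, as it must be for orthogonal `U`.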

However, if frule/rrule returns a named tuple of rules, one can specify accessor functions much more conveniently, something like:

function frule(::typeof(eigvals), X::AbstractMatrix)
    E, R = frule(eigen, X)
    return E.values, R.values
end
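The "named tuple of rules" alternative can be sketched the same way. This also assumes a real symmetric `X` with distinct eigenvalues, and `eigen_with_pushforwards` is a hypothetical name. Each output gets its own closure, so an `eigvals` rule is a plain field access, as above. The cost is that the closures cannot share `Y = U' * ΔX * U` and each one recomputes it.

```julia
using LinearAlgebra

# Sketch: return the primal plus one pushforward closure per output field.
# Assumes X is real symmetric with distinct eigenvalues.
function eigen_with_pushforwards(X::Symmetric)
    E = eigen(X)
    λ, U = E.values, E.vectors
    values_pf(ΔX) = diag(U' * ΔX * U)
    function vectors_pf(ΔX)
        Y = U' * ΔX * U              # recomputed: not shared with values_pf
        n = length(λ)
        F = [i == j ? 0.0 : Y[i, j] / (λ[j] - λ[i]) for i in 1:n, j in 1:n]
        return U * F
    end
    return E, (values = values_pf, vectors = vectors_pf)
end

# An eigvals rule then reduces to selecting one field.
function eigvals_with_pushforward(X::Symmetric)
    E, R = eigen_with_pushforwards(X)
    return E.values, R.values
end
```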

To quote Jarrett directly, at what stage do we want the caller to know that they should deal with a named tuple?

Remove Rule (or maybe all AbstractRules) and treat functions as Rules

Lyndon White Yesterday at 10:18 PM

I am wondering if we will be able to get rid of all subtypes of AbstractRule.
I don't know if the general overloads for indexing and getindex are used at all; they are kind of confusing.
They would be removed by #31 (comment) anyway.
1-arg rule is just a function now.
2-arg rule can maybe be done just by overloading accumulate? (Not sure on this one.)
DNERule is just a function that always returns DNE.
I think we could refactor things so the WirtingerRule goes away and we just have a function that returns a Wirtinger.

Lyndon White 16 hours ago

We can remove them a few at a time and see what breaks.

Simeon Schaub 2 hours ago

I kind of like having a WirtingerRule, because that usually needs to be handled separately since it doesn't implement all the arithmetic. Having a special type means this handling can be done at compile time, which I don't think can be done in a non-hacky way with just a function.

Lyndon White 2 hours ago

Definitely we want the Wirtinger differential.
I am less sure about the WirtingerRule pullback.

Simeon Schaub 1 hour ago

The case I'm thinking of is building up a tape for reverse AD, where you store only the rules in some kind of tree structure. To materialize the results, it was important in my case to know whether the storage type needs to handle Wirtinger derivatives.
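A sketch of both positions, using only hypothetical names (`NoDerivative`, `sin_pushforward`, `dne_rule`, `WirtingerRule`, `needs_wirtinger_storage`), none of which are ChainRulesCore definitions. Plain functions cover the 1-arg and DNE cases. A thin callable wrapper type keeps Simeon's property: code that stores rules, such as a tape, can detect Wirtinger rules by dispatch. That choice can be resolved at compile time whenever the rule's type is inferred.

```julia
# Hypothetical sentinel standing in for DNE().
struct NoDerivative end

# A 1-arg rule is just a closure from an input differential to an output one.
sin_pushforward(x) = Δx -> cos(x) * Δx

# A DNERule is just a function that always returns the sentinel.
dne_rule(args...) = NoDerivative()

# A dedicated callable wrapper still lets downstream code dispatch on
# "this is a Wirtinger rule" instead of inspecting values at runtime.
struct WirtingerRule{F}
    f::F
end
(r::WirtingerRule)(Δ) = r.f(Δ)

needs_wirtinger_storage(::Any) = false
needs_wirtinger_storage(::WirtingerRule) = true
```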

Use a `NoRule` sentinel type instead of `Nothing`?

Currently, if we don't have a rule for something, we return nothing.

I do think it is tempting to return nothing here, but I wonder if it is more explicit to return a NoRule (or NotImplemented) sentinel type, like the DNE we already have to distinguish the case when the derivative does not exist (#17). Thoughts?
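This sketch shows what callers could look like with a sentinel. `NoRule`, `my_frule`, `derivative_at` and `_derivative` are hypothetical names used only for illustration. Instead of checking `=== nothing`, the caller dispatches on the sentinel type. This keeps "no rule was written" separate from "a rule returned something falsy".

```julia
# Hypothetical sentinel: nobody has written a rule for this function.
struct NoRule end

# Hypothetical rule lookup: the fallback returns NoRule() instead of `nothing`.
my_frule(f, args...) = NoRule()
my_frule(::typeof(sin), x::Real) = (sin(x), Δx -> cos(x) * Δx)

# Callers dispatch on the sentinel rather than comparing against `nothing`.
derivative_at(f, x) = _derivative(my_frule(f, x), f, x)
_derivative(rule::Tuple, f, x) = rule[2](one(x))
_derivative(::NoRule, f, x) = error("no rule for $f; an AD system would recurse into it here")
```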

Rules for unary functions: always return a tuple?

They currently return the rule itself, rather than a 1-tuple whose only element is the rule. This is inconsistent with n-ary functions, which return a tuple of rules, and it makes writing code annoying. How would people feel about changing this convention so that unary functions also return a tuple of rules?
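A small sketch of why the uniform convention helps. `sin_rules`, `mul_rules` and `apply_rules` are hypothetical names. When every rule set is a tuple, generic code can map over it without a special case for unary functions.

```julia
# Hypothetical rules under the proposed convention: always return a tuple.
sin_rules(x) = (Δ -> cos(x) * Δ,)               # unary: a 1-tuple
mul_rules(x, y) = (Δ -> y * Δ, Δ -> x * Δ)      # binary: a 2-tuple

# Generic code needs no special case for unary functions.
apply_rules(rules::Tuple, Δ) = map(r -> r(Δ), rules)
```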

add(::Real, ::Wirtinger) throws StackOverflowError

Adding a Wirtinger derivative to Zero(), One(), a real number, or DNE() currently throws a StackOverflowError. It would probably make sense for real numbers to be handled as Wirtinger(x, Zero()), although one would still have to decide how to handle DNE().
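A sketch of the suggested real-number handling, using a stand-in type `Wirt` rather than ChainRulesCore's `Wirtinger`. The StackOverflowError probably comes from generic methods that delegate to one another. Under that assumption, one remedy is to give each combination, including both argument orders, its own concrete method. DNE is left out, since the issue leaves its handling open.

```julia
# Stand-in Wirtinger differential (not the ChainRulesCore definition).
struct Wirt{P,C}
    primal::P     # ∂f/∂z component
    conjugate::C  # ∂f/∂z̄ component
end

# Componentwise addition of two Wirtinger derivatives.
Base.:+(a::Wirt, b::Wirt) = Wirt(a.primal + b.primal, a.conjugate + b.conjugate)

# Treat a real x as Wirt(x, 0): add it to the primal part only.
Base.:+(a::Wirt, x::Real) = Wirt(a.primal + x, a.conjugate)
Base.:+(x::Real, a::Wirt) = a + x
```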

Efficient complex differentiation by structured 2x2 matrices

For efficient complex differentiation, we need to express the following structured matrices:

holomorphic:

[a -b]
[b  a]

anti-holomorphic:

[a  b]
[b -a]

C->R:

[a b]
[0 0]

R->C:

[0 0]
[a b]

general:

[a c]
[b d]

The Wirtinger derivative achieves this by changing basis from x, y to z, z̄. However, that introduces more FLOPs, since you need to transform back to x, y when multiplying by a complex number, i.e. ∂/∂x = ∂/∂z + ∂/∂z̄ and ∂/∂y = i(∂/∂z - ∂/∂z̄).

IMO, structured matrices are far more transparent than Wirtinger derivative, and they don't require a change of basis before multiplying with a complex number.

To implement structured matrices, we could do

struct Holomorphic{T,S}
    a::T
    b::S
end

# [a -b; b a] acting on (real(z), imag(z)) is multiplication by a + b*im
Base.:*(h::Holomorphic, z::Complex) = complex(h.a, h.b) * z
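The other structured cases can also act on a complex number directly. This sketch assumes each matrix acts on the column vector (real(z), imag(z)). `AntiHolomorphic` and `General2x2` are hypothetical names. The anti-holomorphic matrix [a b; b -a] reduces to multiplying conj(z) by a + b*im, so no change of basis is needed.

```julia
# Anti-holomorphic [a b; b -a]: equals (a + b*im) * conj(z).
struct AntiHolomorphic{T}
    a::T
    b::T
end
Base.:*(h::AntiHolomorphic, z::Complex) = complex(h.a, h.b) * conj(z)

# General 2x2 [a c; b d] acting on (real(z), imag(z)).
struct General2x2{T}
    a::T
    b::T
    c::T
    d::T
end
function Base.:*(g::General2x2, z::Complex)
    x, y = reim(z)
    return complex(g.a * x + g.c * y, g.b * x + g.d * y)
end
```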

Define a very basic AD, for testing and for demonstration

We were discussing that it would be good
if ChainRules.jl defined a very simple, naive AD,
like some kind of cut-down Flux.Tracker or Nabla tape-based thing,

for the purposes of demonstrating how ChainRules would be used with an AD,
and for testing the rules.

This could be defined in the source (perhaps in a submodule),
or in the tests.
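Here is a rough idea of how small such a thing could be: a tape-based reverse-mode AD sketch in plain Julia. `tape_rrule` is a hypothetical stand-in for ChainRulesCore's `rrule` that returns `(y, pullback)`. `Tracked`, `record` and `gradient` are illustrative names, not an existing package API. Only `sin`, `*` and `+` are covered.

```julia
# Hypothetical hand-written rules: return (primal, pullback), where the
# pullback maps the output sensitivity to a tuple of input sensitivities.
tape_rrule(::typeof(sin), x) = (sin(x), ȳ -> (ȳ * cos(x),))
tape_rrule(::typeof(*), x, y) = (x * y, z̄ -> (z̄ * y, z̄ * x))
tape_rrule(::typeof(+), x, y) = (x + y, z̄ -> (z̄, z̄))

mutable struct Tracked
    value::Float64
    grad::Float64
end
Tracked(x::Real) = Tracked(float(x), 0.0)

const TAPE = Any[]   # backward steps, recorded in evaluation order

# Evaluate f through its rule and record how to push gradients back.
function record(f, args::Tracked...)
    y, pullback = tape_rrule(f, map(a -> a.value, args)...)
    out = Tracked(y)
    push!(TAPE, () -> begin
        ∂args = pullback(out.grad)
        for (a, ∂a) in zip(args, ∂args)
            a.grad += ∂a          # accumulate: an input may be used twice
        end
    end)
    return out
end

Base.sin(x::Tracked) = record(sin, x)
Base.:*(x::Tracked, y::Tracked) = record(*, x, y)
Base.:+(x::Tracked, y::Tracked) = record(+, x, y)

# Reverse pass: seed the output gradient and replay the tape backwards.
function gradient(f, xs::Real...)
    empty!(TAPE)
    tracked = map(Tracked, xs)
    out = f(tracked...)
    out.grad = 1.0
    for step in reverse(TAPE)
        step()
    end
    return map(t -> t.grad, tracked)
end
```

For example, `gradient((x, y) -> sin(x * y) + x, 2.0, 3.0)` returns the partials `3cos(6) + 1` and `2cos(6)`.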

Make the good error message for "can't construct" only show in DebugMode (for performance reasons)

Resolving #78 sorted out most of the problem, but it didn't fix everything (at least not to a degree that is acceptable, imho).

Running the same example on master now yields:

using ChainRulesCore, BenchmarkTools

struct Foo
    x::Float64
end

foo = Foo(0.5)
Δfoo = Composite{typeof(foo)}(; x=0.4)

@benchmark $foo + $Δfoo
julia> @benchmark $foo + $Δfoo
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     24.149 ns (0.00% GC)
  median time:      24.955 ns (0.00% GC)
  mean time:        25.422 ns (0.00% GC)
  maximum time:     95.542 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     996

(version 1.3.1)

This is obviously way better, but compared to

julia> @benchmark Foo($foo.x + $Δfoo.x)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     0.020 ns (0.00% GC)
  median time:      0.031 ns (0.00% GC)
  mean time:        0.029 ns (0.00% GC)
  maximum time:     0.072 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

it's really quite bad.

If I instead run

using ChainRulesCore: construct, elementwise_add, backing
@benchmark construct($(typeof(foo)), elementwise_add(backing($foo), backing($Δfoo)))
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.693 ns (0.00% GC)
  median time:      1.701 ns (0.00% GC)
  mean time:        1.758 ns (0.00% GC)
  maximum time:     19.254 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

I get something much closer to what would be ideal. The discrepancy between these two code snippets is due to the error handling here.

Perhaps we could compile away this overhead with a generated function somehow? My understanding is that the check is there to warn users that their Composite may have fields that aren't present in the thing they're adding it to, so that they can fix that. Presumably we could catch that kind of problem at compile time in a generated function, and emit either

  • code which is guaranteed not to error in the particular ways we want to avoid, or
  • code which is guaranteed to error whenever it is run, because we know it would otherwise try to access fields of the primal that don't exist.

Any thoughts @oxinabox @nickrobinson251 ?
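Here is a minimal sketch of the generated-function idea, with simplifying assumptions. The differential is a plain NamedTuple rather than a `Composite`, and `add_fields` is a hypothetical helper, not ChainRulesCore's implementation. All field-name checks run once, when the method is generated for a given pair of types. The emitted code then either builds the primal directly or always throws. The generator body uses plain loops, because generated function bodies should not contain closures or comprehensions.

```julia
struct Foo
    x::Float64
end

@generated function add_fields(primal::P, Δ::NamedTuple{names}) where {P,names}
    primal_names = fieldnames(P)
    # Compile-time check: any unknown field means the emitted code always throws.
    for n in names
        if !(n in primal_names)
            msg = "differential has field :$n, which $P does not have"
            return :(throw(ArgumentError($msg)))
        end
    end
    # Otherwise emit a direct constructor call with no runtime checks.
    args = Any[]
    for n in primal_names
        f = QuoteNode(n)
        if n in names
            push!(args, :(getfield(primal, $f) + getfield(Δ, $f)))
        else
            push!(args, :(getfield(primal, $f)))
        end
    end
    return :($P($(args...)))
end
```

How close this gets to the `construct`/`elementwise_add` timings above would still need to be benchmarked.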

Remove Cassette and define promotion rules

This is the "proper" way to solve #5,
since it will let us avoid all the recompilation stuff.
It also makes the code less "magic".

From @ararslan in JuliaDiff/ChainRules.jl#37 (comment)

According to @jrevels, the use of Cassette in this package is purely an implementation detail and could be replaced with regular ol' multiple dispatch. The reason Jarrett didn't do that to begin with is that it requires defining a fair number of methods to resolve ambiguities between the various differential types, rule types, etc. I think going that route will be a cleaner solution overall than replacing the dependency on Cassette with one on IRTools.

See old PR JuliaDiff/ChainRules.jl#38

I am not 100% convinced this is required, particularly if we can just @inline our overdubs and make the compiler optimize them out of existence.

But it is something we should think about.
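Here is one possible illustration of "define promotion rules". It is a sketch using Julia's standard `promote_rule`/`convert` machinery with two hypothetical differential types, `ZeroD` and `Casual`, which are not ChainRulesCore's. Methods are written only for same-type pairs, and mixed pairs go through `promote`. This keeps down the number of methods, and with it the number of potential ambiguities.

```julia
# Hypothetical differential types.
struct ZeroD end
struct Casual{T}
    val::T
end

# Promotion rules: mixing ZeroD with Casual{T} promotes to Casual{T}.
Base.promote_rule(::Type{Casual{T}}, ::Type{ZeroD}) where {T} = Casual{T}
Base.convert(::Type{Casual{T}}, ::ZeroD) where {T} = Casual{T}(zero(T))

# Methods only for same-type pairs; mixed pairs are funnelled through promote.
add(a::Casual, b::Casual) = Casual(a.val + b.val)
add(::ZeroD, ::ZeroD) = ZeroD()
add(a, b) = add(promote(a, b)...)
```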

Rename `Rule`?

"rule" refers to so many things in this package already.

Possible alternatives:

  • GenericRule
  • JRule
  • LRule
  • FuncRule

A trait-based system to handle Wirtinger derivatives

While writing ForwardDiff2.jl I have encountered a problem with nested differentiation with the Wirtinger derivative.

Say I have

x = Dual(Dual(1+1im, 1+0im), 1+0im)

how can I make hypot(value(x)) dispatch to the right place when there are both

@scalar_rule(hypot(x::Real), sign(x))
@scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))

Neither Dual <: Number nor Dual <: Real would work.

https://github.com/JuliaDiff/ChainRules.jl/blob/146d031235b28265772b8c0b98ed54b812c39316/src/rulesets/Base/base.jl#L59-L60

Maybe we need to have a trait-based system to handle Wirtinger derivatives?
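One possible shape for such a system is the "Holy traits" pattern. A trait function works out whether a possibly nested number is ultimately real- or complex-valued, and rules dispatch on that trait instead of on `<: Real`. `MyDual`, `field_trait` and `derivative_kind` are hypothetical names. `MyDual` is a bare-bones stand-in for a ForwardDiff2 dual that stores only a value and one partial.

```julia
abstract type FieldTrait end
struct IsReal <: FieldTrait end
struct IsComplex <: FieldTrait end

# Hypothetical minimal dual number; it is a Number but not a Real.
struct MyDual{T<:Number} <: Number
    value::T
    partial::T
end

# The trait looks through (nested) duals to the underlying scalar type.
field_trait(::Type{<:Real}) = IsReal()
field_trait(::Type{<:Complex}) = IsComplex()
field_trait(::Type{MyDual{T}}) where {T} = field_trait(T)

# Rules dispatch on the trait instead of on `x isa Real`.
derivative_kind(x::Number) = derivative_kind(field_trait(typeof(x)), x)
derivative_kind(::IsReal, x) = :ordinary
derivative_kind(::IsComplex, x) = :wirtinger
```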

Separating frule and pushforward prevents efficient solutions (fuse pushforward)

When implementing the frules, I realized that the current design of frule doesn't allow for the standard optimizations seen in forward-mode rule implementations. The reason is that forward mode propagates the derivative along step by step. A good primer on all of this is this set of notes:

https://mitmath.github.io/18337/lecture9/autodiff_dimensions

Essentially what we are trying to do with ChainRules.jl is allow the user to describe how to calculate f(x) and f'(x)v, the primal and the jvp. Currently the formulation is:

function frule(::typeof(foo), args; kwargs...)
    ...
    return y, pushforward
end

where pushforward is called as pushforward(dargs). However, given that discussion of forward-mode differentiation, one can see that this runs contrary to how the derivative is actually calculated. Here are two examples:

Example 1: Implementing ForwardDiff over frules

As described in the notes, the dual number way of computing forward mode starts by seeding dual numbers. In standard ForwardDiff usage, these seeds are all unique basis vectors, like is shown in the DiffEq documentation for how to AD through the solver manually:

https://docs.juliadiffeq.org/v6.8/analysis/sensitivity/#Examples-using-ForwardDiff.jl-1

But as mentioned in the notes, what this is really doing is seeding the duals in the basis vector e_i directions, so the jvp computes J*e_1, J*e_2, J*e_3 as separate vectors, giving a representation of the full Jacobian. Once you have the whole Jacobian you can of course compute J*v, and this is what the current frule allows:

function frule(::typeof(f), x; kwargs...)
    dual_x = seed_duals(x)   # seeds along the basis vector directions,
                             # giving a dual number with length(x) partials
    dual_y = f(dual_x)
    y, dy = value(dual_y), partials_as_matrix(dual_y)  # y is primal, dy is the Jacobian
    function pushforward(dx)
        return dy * dx
    end
    return y, pushforward
end

However, this shows that there is a more efficient way to calculate y and dy*dx: if we know dx at the start, we can seed the dual numbers along the direction of dx, which reduces the number of dual dimensions from length(x) to 1:

function frule(::typeof(f), x, dx; kwargs...)
    dual_y = f(dual.(x,dx)) # 2 dimensional number
    y,dy = value(dual_y),partials_as_matrix(dual_y) # y is primal, dy is f'(x)*dx
    return y, dy 
end

This reduces the number of partials carried through f from length(x) to 1, i.e. from O(n) to O(1)!
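The difference can be seen with a tiny dual-number sketch. `D1` is a hypothetical single-partial dual (a stand-in for `ForwardDiff.Dual`), and `example_f`, `fused_frule` and `jacobian_then_multiply` are illustrative names. The fused version seeds every input along `dx` and gets `J*dx` from one pass. The unfused version, for simplicity, builds the Jacobian with n single-partial passes rather than one n-partial pass, and then multiplies by `dx`. Either way its cost grows with n.

```julia
# Minimal dual number carrying a single directional derivative.
struct D1 <: Number
    v::Float64   # primal value
    p::Float64   # directional derivative
end
Base.:+(a::D1, b::D1) = D1(a.v + b.v, a.p + b.p)
Base.:*(a::D1, b::D1) = D1(a.v * b.v, a.p * b.v + a.v * b.p)
Base.sin(a::D1) = D1(sin(a.v), cos(a.v) * a.p)

# Example function R^2 -> R^2.
example_f(x) = [sin(x[1]) * x[2], x[1] + x[2] * x[2]]

# Fused frule: seed each input along dx; one pass yields y and J*dx.
function fused_frule(f, x::Vector{Float64}, dx::Vector{Float64})
    ydual = f(D1.(x, dx))
    return [d.v for d in ydual], [d.p for d in ydual]
end

# Unfused: one pass per basis vector e_j to build J, then multiply by dx.
function jacobian_then_multiply(f, x::Vector{Float64}, dx::Vector{Float64})
    n = length(x)
    cols = [[d.p for d in f(D1.(x, [i == j ? 1.0 : 0.0 for i in 1:n]))] for j in 1:n]
    J = reduce(hcat, cols)
    return J * dx
end
```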

Example 2: Implementation of Forward Sensitivity Analysis for ODEs

Now here's a bit more concrete example for the user side. For ODEs, you want to look at:

u' = f(u,p,t)

and you want to know how u(t) changes w.r.t. p. So take d/dp of both sides of the ODE, and by the chain rule you get (swapping the order of differentiation, assuming nice properties)

d/dt du/dp = df/du du/dp + df/dp

calling S = du/dp, this is just

S' = (df/du)*S + df/dp

So you get another ODE that gives you the derivatives of the solution of the original ODE w.r.t. the parameters. This is the continuous pushforward rule! The difficulty is that you need to be able to calculate (df/du)(t), which requires knowing u(t). In theory you could calculate u(t) as a continuous solution beforehand by solving the previous ODE and storing it, but that's not the best way to do it. The way you do it is to realize that, if you solve the ODE:

u' = f(u,p,t)
S' = (df/du)*S + df/dp

together, then you always know u since it's the first part of the equation! So magic happens and this is very efficient.

That's almost there. What sensitivities are we pushing forward though? You can seed the sensitivities from S=0 and the output S = du/dp, but that's not satisfying. What if you wanted to know du/d(u0) and du/dp? Since concrete_solve(p,u0,odeprob,solver,...) is a function of both p and u0, we want the derivative of the ODE's solution with respect to the p and the u0.

It turns out from simple math that all you have to do is set S = du0! So then, in "composed frule" notation, you'd do the following:

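function frule(::typeof(concrete_solve), p, dp, u0, du0, odeprob, solver)
  S = du0                                          # initial tangent: S(0) = du0
  # augmented ODE: u' = f(u,p,t), S' = (df/du)*S + (df/dp)*dp
  _prob = build_bigger_ode(odeprob, [u0, S], dp)
  sol = solve(_prob, solver)
  y, dy = split_solution(sol)
  return y, dy    # dy is already weighted by the direction (du0, dp)
end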

Right now, this can't really be expressed.
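The fused formulation can be made concrete with the scalar decay ODE u' = -p*u. This sketch uses a fixed-step explicit Euler scheme, a deliberate simplification and not what DifferentialEquations.jl does, and `euler_decay_frule` is a hypothetical name. The primal u and the tangent S advance together, so the current u is always available for df/du. S starts at du0, and in this sketch the df/dp term is weighted by dp inside the tangent ODE.

```julia
# Fused primal + tangent integration of u' = -p*u on [0, T] with explicit Euler.
# Tangent ODE: S' = (df/du)*S + (df/dp)*dp = -p*S - u*dp, with S(0) = du0.
function euler_decay_frule(p, dp, u0, du0; T = 1.0, nsteps = 10_000)
    h = T / nsteps
    u, S = u0, du0
    for _ in 1:nsteps
        du = -p * u
        dS = -p * S - u * dp   # uses the current u: no stored trajectory needed
        u += h * du
        S += h * dS
    end
    return u, S
end
```

For this ODE the exact answer is u(T) = u0*exp(-p*T), with directional derivative exp(-p*T)*(du0 - T*u0*dp), which the Euler result approaches as the step size shrinks.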

API

Actually having those arguments might be difficult, so maybe it's easier to write as:

function frule(::typeof(f), x; kwargs...)
    function pushforward(dx)
        dual_y = f(dual.(x, dx)) # 2 dimensional number
        y, dy = value(dual_y), partials_as_matrix(dual_y) # y is primal, dy is f'(x)*dx
        return y, dy
    end
    return pushforward
end

Anyways, the exact API is an interesting question, but whatever it is, the computation should have the x and the dx at the same time.

Move integration CI tests to Cirrus or AppVeyor

If they ran on a separate CI service, we could still easily see when they fail,
but we would also easily see that ChainRulesCore itself hasn't failed.

Also, because ChainRules itself takes a while to test, we could run these in parallel.

Primal + Composite perf

The following is a bit sad:

using ChainRulesCore, BenchmarkTools

struct Foo
    x::Float64
end

foo = Foo(0.5)
Δfoo = Composite{typeof(foo)}(; x=0.4)

@benchmark $foo + $Δfoo

BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  6
  --------------
  minimum time:     1.044 μs (0.00% GC)
  median time:      1.071 μs (0.00% GC)
  mean time:        1.079 μs (0.00% GC)
  maximum time:     3.762 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

Most of the time this isn't a problem. Some of the time (currently for me) it is 🙁 .

Evaluating rules is unreasonably slow on first call

Currently, evaluating rules can be pretty slow (and in some cases extremely slow) on the first call. It seems to be due to JIT compiling the anonymous functions, the use of which is a central design point of ChainRules.

Here's a basic example, where it takes 2.6 seconds(!!!) to evaluate the derivative for svd. Note that the ChainRules version is a simple port of the Nabla version, so the underlying code that does the computation is nearly identical.

julia> using ChainRules, LinearAlgebra

julia> F, dX = rrule(svd, randn(4, 4));

julia> nt = (U=F.U, S=F.S, V=F.V);

julia> @time dX(nt)
  2.644154 seconds (11.32 M allocations: 589.599 MiB, 6.32% gc time)
4×4 Array{Float64,2}:
  0.189976   0.714807   0.456714    2.18944 
 -0.313175  -0.760914  -0.0347824   0.465126
 -0.53937    1.4229    -1.20063    -1.46798 
 -0.517822   1.03534   -1.66866     0.263543

julia> @time dX(nt)
  0.017788 seconds (2.15 k allocations: 146.201 KiB)
4×4 Array{Float64,2}:
  0.189976   0.714807   0.456714    2.18944 
 -0.313175  -0.760914  -0.0347824   0.465126
 -0.53937    1.4229    -1.20063    -1.46798 
 -0.517822   1.03534   -1.66866     0.263543

Compare this to Nabla:

julia> using Nabla, LinearAlgebra

julia> X = randn(4, 4); F = svd(X); nt = (U=F.U, S=F.S, V=F.V);

julia> @time ∇(svd, Arg{1}, (), F, nt, X)
  0.631797 seconds (2.37 M allocations: 114.680 MiB, 3.77% gc time)
4×4 Array{Float64,2}:
 -1.09959    1.16123    2.27612    -0.97398 
  1.42299   -0.17832   -2.52324     1.48229 
 -3.32982   -0.746237   1.1226      2.98457 
  0.346265   0.402947  -0.0502055   0.661208

julia> @time dX(nt)
  0.006206 seconds (2.15 k allocations: 146.201 KiB)
4×4 Array{Float64,2}:
  0.189976   0.714807   0.456714    2.18944 
 -0.313175  -0.760914  -0.0347824   0.465126
 -0.53937    1.4229    -1.20063    -1.46798 
 -0.517822   1.03534   -1.66866     0.263543

We should find some way(s) to mitigate this so that AD systems which switch to using ChainRules underneath won't take an enormous performance hit by doing so.

Explicit Rules for Higher Order "Adjoints"

Can the API for adding an adjoint rule allow for explicitly specifying the rule for a higher order adjoint?

e.g. D(sin,x) = v -> v * cos(x) but I also know that D(sin,x; n=2) = (v1,v2) -> v2*v1*(- sin(x))
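As a sketch of what an explicit higher-order rule could look like, take a hypothetical `D(f, x; n)` that returns the n-th order derivative as a function of n tangents. This is not an existing ChainRules API. A finite difference of the first-order rule is a cheap consistency check on the hand-written second-order rule.

```julia
# Hypothetical API: explicit first- and second-order rules for sin.
function D(::typeof(sin), x; n::Int = 1)
    if n == 1
        return v -> v * cos(x)
    elseif n == 2
        return (v1, v2) -> v2 * v1 * (-sin(x))
    else
        throw(ArgumentError("no explicit rule for order $n"))
    end
end
```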

Differentiating with respect to a function

I think this may be covered by #4, but I just want to add this as a use case: differentiating

const A = randn(10,10)
function f(y)
    N = length(y)
    g(x) = sum(abs2, A*x-y)
    x = nlsolve(g, zeros(N)) # with appropriate jacobian information, etc
    sum(x.*y)
end
f(randn(10))

The goal would be to define a rule for nlsolve that would compute the derivative of x using the jacobian information and the derivative of g wrt y.
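Such a rule would typically use the implicit function theorem. Here is a sketch under simplifying assumptions. The least-squares objective is replaced by the residual form F(x, y) = A*x - y = 0, which has the same minimizer when A is square and invertible. `newton_solve` is a hypothetical stand-in for nlsolve, and `solve_rrule` and `f_and_grad` are illustrative names. The pullback never differentiates through the solver's iterations; it needs only the Jacobians at the solution.

```julia
using LinearAlgebra

# Hypothetical stand-in for nlsolve: Newton steps with a user-supplied Jacobian.
function newton_solve(F, J, x0; iters = 20)
    x = copy(x0)
    for _ in 1:iters
        x -= J(x) \ F(x)
    end
    return x
end

# Implicit-function-theorem pullback for F(x*(y), y) = 0:
# ȳ = -(∂F/∂y)' * ((∂F/∂x)' \ x̄). Here ∂F/∂x = A and ∂F/∂y = -I.
function solve_rrule(A, y)
    x = newton_solve(x -> A * x - y, z -> A, zeros(length(y)))
    pullback(x̄) = A' \ x̄
    return x, pullback
end

# Outer objective from the issue: f(y) = sum(x(y) .* y).
function f_and_grad(A, y)
    x, pb = solve_rrule(A, y)
    val = sum(x .* y)
    # d val / dy = (dx/dy)' * y + x; the first term comes from the pullback.
    return val, pb(y) + x
end
```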

CachingThunks

It came up in discussion of JuliaDiff/ChainRules.jl#21
that it would be handy to have an object which caches its value when evaluated
(probably a Differentiable, but potentially not?).
That would let us handle computations that might be shared by several of the multiple returned differentials.

A possible alternative is to just give this behaviour to all Thunks.
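A minimal sketch of a caching thunk. `CachingThunk` and `unthunk!` are hypothetical names, not ChainRulesCore definitions. The thunk runs its function at most once and memoizes the result, so two returned differentials can share one expensive intermediate. The cache is untyped for brevity; a type parameter for the value would avoid the dynamic dispatch.

```julia
# Hypothetical caching thunk: runs `f` at most once and memoizes the result.
mutable struct CachingThunk{F}
    f::F
    computed::Bool
    value::Any          # untyped cache for simplicity
end
CachingThunk(f) = CachingThunk(f, false, nothing)

function unthunk!(t::CachingThunk)
    if !t.computed
        t.value = t.f()
        t.computed = true
    end
    return t.value
end

# Usage: two differentials share one expensive intermediate.
expensive_calls = Ref(0)
shared = CachingThunk(() -> (expensive_calls[] += 1; inv([2.0 0.0; 0.0 4.0])))
∂a = () -> unthunk!(shared) * [1.0, 1.0]
∂b = () -> unthunk!(shared)' * [1.0, 0.0]
```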
