
oxinabox commented on June 30, 2024

Can we close this now that we have ProjectTo going on?
It seems like the answer is that if something subtypes AbstractArray, we would in general prefer to handle it (if we handle it at all) as an AbstractArray, and fix it up in ProjectTo.
I think that makes sense, because that is the natural differential, which is what people want to write rules for,
and it makes it possible to call functions on it that do expect arrays.
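As a toy illustration of projecting a natural (dense-array) differential onto a structural one, here is a minimal sketch. `ToyFill` and `project_structural` are hypothetical names, not ChainRulesCore's `ProjectTo` API or the FillArrays type; the only assumption is that a Fill-like array has a single free field, `value`, so its structural differential is the sum of the element-wise differentials.

```julia
# Toy stand-in for a Fill: every element equals `value` (hypothetical, not FillArrays.Fill)
struct ToyFill <: AbstractVector{Float64}
    value::Float64
    n::Int
end
Base.size(f::ToyFill) = (f.n,)
Base.getindex(f::ToyFill, i::Int) = f.value

# Hypothetical projection: collapse a dense ("natural") differential into a
# structural one. Only `value` can vary, and every element depends on it with
# derivative 1, so d/d(value) is the sum of the element-wise differentials.
project_structural(::ToyFill, dx::AbstractVector) = (value = sum(dx),)

x = ToyFill(1.0, 3)
dx = [0.5, 1.0, 1.5]          # natural differential, as a generic array rrule would return
project_structural(x, dx)     # (value = 3.0,)
```

Rules can then be written against plain arrays, and only the projection step needs to know about the structured type.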

cc @mcabbott @mzgubic

from chainrulescore.jl.

oxinabox commented on June 30, 2024

Both are acceptable,
as long as it can be added to itself and to the primal.
Ideally it would be addable to all differentials that can be added to the primal,
but that can always be solved by adding a new + method, or, more likely, by noting that one of the differential types is reasonable (in this case Composite) and the other is not (in this case the subtype of AbstractArray).

Though we do run into the expression problem (#53), since the code most likely expects the thing to be an AbstractMatrix,
but the solution to that is to solve #53, not anything else.

willtebbutt commented on June 30, 2024

As pointed out by @MikeInnes here and on Slack, allowing a choice of representation for the differential of a Fill produces problems. In particular, consider something like

using FillArrays, ChainRulesCore  # Composite is ChainRulesCore's older name for Tangent
x = Fill(1.0, (2, 2))
y = Composite{typeof(x)}(; value=2.0)
x + y

Whether the behaviour of + here should be

Composite{typeof(x)}(; value=sum(x) + y.value)  # x read as a differential

or

Fill(FillArrays.getindex_value(x) + y.value, size(x))  # x read as a primal

depends on whether you're treating x as a differential or as a primal: you probably want the former behaviour when executing the reverse pass of AD (differential + differential), and the latter when doing stuff with the results of AD (primal + differential).
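The two readings give numerically different answers. A minimal sketch in plain Julia, using a hypothetical `FillDemo` struct and a NamedTuple in place of FillArrays and `Composite` so that it runs without either package:

```julia
# Hypothetical stand-in for Fill(value, dims); not the FillArrays type
struct FillDemo
    value::Float64
    dims::Tuple{Int,Int}
end

x = FillDemo(1.0, (2, 2))
y = (value = 2.0,)   # plays the role of Composite{typeof(x)}(; value=2.0)

# Reading x as a differential: collapse it to structural form
# (value = sum of its elements) and add field-wise.
diff_plus_diff(x::FillDemo, y) = (value = x.value * prod(x.dims) + y.value,)

# Reading x as a primal: step its stored value by the differential.
primal_plus_diff(x::FillDemo, y) = FillDemo(x.value + y.value, x.dims)

diff_plus_diff(x, y)    # (value = 6.0,)
primal_plus_diff(x, y)  # FillDemo(3.0, (2, 2))
```

Same inputs, two defensible results (6.0 versus a fill of 3.0), which is why a single `+` cannot serve both roles without more information.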

I think there are a couple of solutions here:

  • distinguish between + and something similar to Zygote.accum. + would define primal + differential stuff, and accum differential + differential stuff.
  • Insist that the type of an object tells us enough to know how to define + unambiguously. This would entail requiring that whenever we encounter a situation like the above, the primal + differential definition would apply, and it's the job of the rule-writer to ensure that no frule or rrule returns a Fill to represent a differential.
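A minimal sketch of the first option, assuming NamedTuples as differentials and a hypothetical `FillAcc` type. `accum` here is a hypothetical function modelled on `Zygote.accum`; the point is only that the function called, not the argument types, decides which meaning applies:

```julia
# Hypothetical Fill stand-in; differentials are NamedTuples with a `value` field
struct FillAcc
    value::Float64
    n::Int
end

# `+` is reserved for primal + differential ...
Base.:+(x::FillAcc, dx::NamedTuple{(:value,)}) = FillAcc(x.value + dx.value, x.n)

# ... while `accum` handles differential + differential, so a FillAcc passed
# here is read as a differential and collapsed to its structural form first.
accum(dx::NamedTuple{(:value,)}, dy::NamedTuple{(:value,)}) = (value = dx.value + dy.value,)
accum(dx::FillAcc, dy::NamedTuple{(:value,)}) = accum((value = dx.value * dx.n,), dy)

x = FillAcc(1.0, 4)
dy = (value = 2.0,)
x + dy          # FillAcc(3.0, 4)
accum(x, dy)    # (value = 6.0,)
```

The second option would instead drop the `accum(::FillAcc, ...)` method entirely and forbid rules from ever returning a `FillAcc` as a differential.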

I could be convinced of either approach I think, but I instinctively prefer the latter as the former feels like a patch -- better to disambiguate between fundamentally different objects than to make + mean different things in different contexts imho. But a case could of course be made the other way.

MikeInnes commented on June 30, 2024

I'm really sceptical that a sensible design can be had in which we have both composite and abstractarray gradients.

Will's concern in FluxML/Zygote.jl#445 was that using a FillArray as the gradient was hacky (I'm taking the core issue to be that gradients are inconsistent with the equivalent dense array). But what I've tried to illustrate is that this is inevitable with any version of this, because if you follow through on Composite it ends up behaving the same way as FillArray; if you go that route you may as well just use FillArray and not pretend.

If on the other hand that tradeoff is unacceptable, you have to use dense (or perhaps one-hot) arrays, which again leads to throwing out Composite.

This will be true for any array type. If you mix composite and array gradients you're going to get a somewhat arbitrary mix of different kinds of derivatives back out.

MikeInnes commented on June 30, 2024

At some point you'll have to accumulate the differential for x[1] and x[2], whose sum is clearly not also a 1-hot array.
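A minimal sketch with a hypothetical `OneHotDiff` type (not a type from OneHotArrays or Flux) showing why a one-hot representation is not closed under accumulation:

```julia
# Hypothetical one-hot differential: a single nonzero entry `val` at index `i`
struct OneHotDiff
    i::Int
    val::Float64
    n::Int
end
densify(d::OneHotDiff) = [k == d.i ? d.val : 0.0 for k in 1:d.n]

# Two one-hots at the same index stay one-hot; at different indices the sum
# has two nonzeros, so this sketch falls back to a dense vector.
accum_onehot(a::OneHotDiff, b::OneHotDiff) =
    a.i == b.i ? OneHotDiff(a.i, a.val + b.val, a.n) : densify(a) .+ densify(b)

dx1 = OneHotDiff(1, 1.0, 4)   # differential from x[1]
dx2 = OneHotDiff(2, 1.0, 4)   # differential from x[2]
accum_onehot(dx1, dx2)        # [1.0, 1.0, 0.0, 0.0]
```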

Also, while I appreciate that this is meant to be more general, exactly the same O(N^2) behaviour affects regular arrays. It might be that the true solution is simply to fix this and not treat Fills as special at all (i.e. making array gradients mutable and updating them in place, which fixes the complexity).
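To make the complexity argument concrete, here is a minimal sketch that counts element writes (a stand-in for real cost, not a benchmark) for the reverse pass of summing every `x[i]`. The function names are hypothetical:

```julia
# Reverse pass of `sum(x[i] for i in 1:n)`: each x[i] contributes a one-hot
# differential that must be accumulated into the gradient of x.

# Immutable differentials: every step builds a new one-hot and a new sum,
# touching O(n) elements per step, so O(n^2) element writes overall.
function accumulate_immutable(n)
    g = zeros(n)
    writes = 0
    for i in 1:n
        onehot = zeros(n)
        onehot[i] = 1.0
        g = g .+ onehot          # allocates a fresh length-n array
        writes += 2n + 1         # n zeros + 1 entry for onehot, n for the new g
    end
    return g, writes
end

# Mutable gradient updated in place: O(1) work per step, O(n) overall.
function accumulate_inplace(n)
    g = zeros(n)
    writes = 0
    for i in 1:n
        g[i] += 1.0
        writes += 1
    end
    return g, writes
end

accumulate_immutable(100)[2]   # 20100
accumulate_inplace(100)[2]     # 100
```

Both produce the same gradient; only the amount of work differs.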

willtebbutt commented on June 30, 2024

> Also, while I appreciate that this is meant to be more general, exactly the same O(N^2) behaviour affects regular arrays. It might be that the true solution is simply to fix this and not treat Fills as special at all (i.e. making array gradients mutable and updating them in place, which fixes the complexity).

The point here is more that in the dense case this is a necessary consequence of the immutability of differentials (assuming you've not got some fancy compiler optimisation going on). Conversely, this really shouldn't be the case for Fills with immutable differentials.
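By contrast, with an immutable structural differential for a Fill, the same reverse pass only ever updates one scalar. A minimal sketch, assuming the differential is stored as a NamedTuple with a single `value` field; `accumulate_fill_structural` is a hypothetical name:

```julia
# Hypothetical structural differential for a Fill: the single scalar d/d(value).
# Each x[i] reads the same field, so its pullback adds the incoming cotangent to
# that scalar: O(1) per contribution, with no length-n arrays allocated.
function accumulate_fill_structural(n)
    dvalue = 0.0                  # corresponds to Composite(; value = dvalue)
    for _ in 1:n
        ȳ = 1.0                   # cotangent of x[i] in sum(x[i] for i in 1:n)
        dvalue += ȳ               # ∂x[i]/∂value == 1 for every i
    end
    return (value = dvalue,)
end

accumulate_fill_structural(100)   # (value = 100.0,)
```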

> I'm really sceptical that a sensible design can be had in which we have both composite and abstractarray gradients.

I think I probably agree. That said, we can definitely have Reals and Arrays representing the differentials / adjoints of Reals and Arrays respectively.

> because if you follow through on Composite it ends up behaving the same way as FillArray; if you go that route you may as well just use FillArray and not pretend.

I'm sorry, but I'm still not seeing the inevitability of this. Assuming that you do indeed always represent the differential of a Fill as a Composite, the numbers that you get out when adding to a primal are quite different.

MikeInnes commented on June 30, 2024

It depends on what the exact proposal is; the discussion on the Zygote issue indicated a mixing of array and composite adjoints (i.e. using composites for efficiency where possible but keeping array semantics broadly), so my comments are mainly aimed at attempts to preserve current semantics with better efficiency. I'm not certain but it seems like this issue is more about throwing out array semantics entirely, and defining the derivative only wrt the fill value, for which the earlier examples indeed don't apply.

This case is a bit different because gradients of FillArrays themselves are now unambiguous, but there are still ways in which their ambiguity can leak out. Contrived example (since you mentioned upstream rrules in the original issue):

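using Zygote: @adjoint, gradient
using FillArrays, Statistics

# Contrived adjoint whose primal sometimes returns a Fill instead of a dense array
@adjoint function (a::AbstractMatrix * b::AbstractVector)
  y = a * b
  y = all(==(first(y)), y) ? Fill(mean(y), size(y)) : y
  y, c̄ -> (c̄ * b', a' * c̄)
end

gradient(x -> (W * x)[1], x)  # for some matrix W and vector x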

If W*x produces a fill, it will get a Composite gradient back. It obviously has to be able to deal with this, which means it has to interpret the gradient as an array somehow, with an inherent ambiguity as to what elements of y are responsible for the gradient. You don't have to choose to interpret the Composite as a FillArray here – you could use a random array with the same sum – but it raises the same issues as if you did. In this case it's actually the gradients of two innocent bystanders – W and x – that end up being ambiguous.

Even if it were easy to keep Fill out of all the abstract array gradients, it's hard to see a way around this, since in generic code pretty much any adjoint can produce a Fill at any time, and those generic adjoints would then all need to be updated to know how to convert Composite{Fill} into something like an array.

willtebbutt commented on June 30, 2024

Firstly, apologies for the slow response.

> It depends on what the exact proposal is; the discussion on the Zygote issue indicated a mixing of array and composite adjoints (i.e. using composites for efficiency where possible but keeping array semantics broadly), so my comments are mainly aimed at attempts to preserve current semantics with better efficiency. I'm not certain but it seems like this issue is more about throwing out array semantics entirely, and defining the derivative only wrt the fill value, for which the earlier examples indeed don't apply.

I think this is what this issue is more about now -- your comments on the ambiguities inherent in mixing composite and array types for differentials seem pretty compelling.

> If W*x produces a fill, it will get a Composite gradient back. It obviously has to be able to deal with this, which means it has to interpret the gradient as an array somehow, with an inherent ambiguity as to what elements of y are responsible for the gradient. You don't have to choose to interpret the Composite as a FillArray here – you could use a random array with the same sum – but it raises the same issues as if you did. In this case it's actually the gradients of two innocent bystanders – W and x – that end up being ambiguous.

I don't entirely follow; specifically, it's unclear to me why it's valid to choose any array with the same sum if you receive a Composite.

On the topic of using a Fill as the differential of a Fill, what if I have a Fill whose eltype is ComplicatedType for which division-by-real isn't defined, and hence the mean isn't defined? In this case I can't see how you could represent the differential as a Fill.

MikeInnes commented on June 30, 2024

Assuming you buy the xs[i] ≡ mean(xs) argument, then xs[i] is also equivalent to any weighted mean of xs (where the weights sum to one). With random weights you get a random gradient with the same sum; and we can reduce all Fill operations to xs[i].

We should be clear, though, that the mean(xs) example is more of a tool for intuition in a particular case than something fundamental. The basic issue is that if we have only a Composite gradient and we need an element-wise gradient, the Composite only tells us the sum of the adjoints of all elements, and nothing about how individual elements contributed; which means we either have to throw an error here or pick a distribution somewhat arbitrarily (and we might motivate this by analogy to a specific operation, or by what's convenient performance-wise).
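A minimal sketch of that ambiguity, assuming the structural differential of a length-n Fill is just the scalar `dvalue` (the sum of the element adjoints): any weight vector summing to one expands it into an element-wise gradient consistent with that sum. `expand` is a hypothetical helper.

```julia
using Random

# Expand a structural differential back to an array with weights w, sum(w) == 1.
# Every such choice preserves the sum; nothing in `dvalue` picks one over another.
expand(dvalue, w) = dvalue .* w

n, dvalue = 4, 2.0
uniform = fill(1 / n, n)                               # analogous to mean(xs)
onehot  = [1.0, 0.0, 0.0, 0.0]                         # analogous to xs[1]
r = rand(MersenneTwister(0), n); random = r ./ sum(r)  # random weights

for w in (uniform, onehot, random)
    g = expand(dvalue, w)
    println(g, "  sum = ", sum(g))
end
```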

So the fact that sum(xs)/length(xs) is a weird operation for, say, Dates is not that important; we could probably contrive something that produces the equivalent gradient, but even if not, the idea that the Composite adjoint represents the sum of the adjoints of the array elements (whether represented as milliseconds or composites) is still well defined, and we still face the same issues when it comes to array operations.
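As a minimal sketch of that point, assuming a hypothetical element type `Shift` whose differentials support `+` but not division by a number: the summed (structural) differential exists, while a mean, and hence a Fill-of-the-mean representation, does not.

```julia
using Statistics

# Hypothetical element type: differentials can be added but not divided by a Real
struct Shift
    ms::Int
end
Base.:+(a::Shift, b::Shift) = Shift(a.ms + b.ms)
Base.zero(::Type{Shift}) = Shift(0)

dxs = [Shift(10), Shift(20), Shift(30)]   # element-wise adjoints
total = sum(dxs)                           # Shift(60): the structural differential

# mean(dxs) needs Shift / Int, which is not defined, so it throws a MethodError
mean_defined = try
    mean(dxs)
    true
catch err
    err isa MethodError || rethrow()
    false
end
```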

willtebbutt commented on June 30, 2024

> Assuming you buy the xs[i] ≡ mean(xs) argument, then xs[i] is also equivalent to any weighted mean of xs (where the weights sum to one). With random weights you get a random gradient with the same sum; and we can reduce all Fill operations to xs[i].

Okay, I think I follow.

I think my view on this is that we should never need an element-wise gradient for a Fill. Certainly, this means that you can't just expect rrules defined for AbstractArrays to necessarily work.

Taking your example, the rule-writer has the information required to know unambiguously what the gradient should be by virtue of the fact that mean was called internally inside * -- all of the gradients should be the same in the Fill case. If you had not written a rule for this, the AD would have been able to figure that out. I know I've taken your example quite literally, but it's unclear to me that this demonstrates ambiguity leakage.

MikeInnes commented on June 30, 2024

The core point is really that switching to Composite gradients changes the gradients we get for W and x, in some cases, compared to the current situation; this is what (it seems) we want to avoid. Ok, sure, you can define away the ambiguity by looking at the implementation of the primal -- but if we're redefining what counts as a correct gradient we may as well make it convenient. Does the adjoint author have to keep an eye on the package providing the primal, and change the adjoint if mean(y) gets changed to y[1]?

More broadly, I'm not imagining that people will actually write code like this, but Fills can be emitted from generic code perhaps as a result of interactions between multiple packages. If an innocuous-looking adjoint happens to produce a Fill, the gradient will error out. To fix that we have to effectively audit the implementation (potentially a lot of code written by other people) in order to define the gradient of the Fill and then have dispatch in the adjoint to handle that case (and later make sure it stays in sync with the package). At the end of this we are still changing gradients from their dense array equivalents, and in a way that doesn't have any semantic guarantees as package implementations change.
