Comments (11)
Can we close this now that we have ProjectTo going on?
It seems like the answer is: if something subtypes `AbstractArray`, we would in general prefer to handle it (if we handle it) as an `AbstractArray`, and fix it in the `ProjectTo`.
Which I think makes sense, because it is the natural differential, which people want to write rules for,
and this facilitates calling functions on it that do expect arrays.
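For concreteness, here is a toy sketch of what "handle it as an `AbstractArray` and fix it in the `ProjectTo`" could mean. It is written in Python with hypothetical `Fill` / `project_to_fill` stand-ins for the Julia types (this is not the actual ChainRulesCore API): rules emit the natural, array-shaped differential, and a projection step standardizes it back onto the primal's structure.

```python
# Hypothetical mock of the ProjectTo idea: rules emit a "natural"
# array-shaped differential, and a projector standardizes it back
# onto the structure of the primal (here, a Fill-like array).
from dataclasses import dataclass

@dataclass
class Fill:
    value: float
    size: tuple  # all elements equal `value`

def project_to_fill(primal: Fill, dx):
    """Project an arbitrary array differential onto a Fill differential.

    The natural array differential is a list of per-element adjoints;
    the structural differential of a Fill is just d(value), which is
    the sum of the element adjoints (value appears in every slot).
    """
    n = 1
    for d in primal.size:
        n *= d
    assert len(dx) == n
    return Fill(sum(dx), primal.size)

x = Fill(1.0, (2, 2))
dx = [0.5, 0.25, 0.25, 1.0]      # dense differential from a generic rule
print(project_to_fill(x, dx))    # Fill(value=2.0, size=(2, 2))
```

The sum-over-elements rule falls out of the chain rule: `value` appears in every slot of the primal, so its adjoint is the total of the element adjoints.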
from chainrulescore.jl.
Both are acceptable, as long as it can be added to itself and to the primal.
Ideally it would be addable to all differentials that can be added to the primal, but that can always be solved by adding a new `+` method. Or, more likely: one of the differential types is reasonable (in this case `Composite`) and the other is not (in this case the subtype of `AbstractArray`).
Though we do run into the expression problem (#53), since the code most likely expects the thing to be an `AbstractMatrix`, but the solution to that is to solve #53, not anything else.
As pointed out by @MikeInnes here and on Slack, allowing for a choice of representation of the differential of a `Fill` produces problems here. In particular, consider something like
```julia
x = Fill(1.0, (2, 2))
y = Composite{typeof(x)}(; value=2.0)
x + y
```
Whether the behaviour of `+` here should be `sum(x) + y` or `Fill(getindex_value(x) + y.value, size(x))` depends on whether you're treating `x` as a differential or as a primal. I.e. you probably want the former behaviour when executing the reverse-pass of AD (differential + differential), and the latter when doing stuff with the results of AD (primal + differential).
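To make the two candidate behaviours concrete, here is a minimal Python mock (hypothetical stand-ins for `Fill` and `Composite`, not the real Julia types):

```python
# Toy illustration of the two candidate meanings of `x + y` when
# `x` is a Fill and `y` is a structural tangent carrying d(value).
from dataclasses import dataclass

@dataclass
class Fill:
    value: float
    size: tuple

@dataclass
class Composite:      # structural tangent of a Fill: d(value)
    value: float

def add_as_differentials(x: Fill, y: Composite):
    # differential + differential: the total adjoint carried by x is
    # the sum over its elements, accumulated into y's d(value)
    n = x.size[0] * x.size[1]
    return Composite(n * x.value + y.value)

def add_as_primal(x: Fill, y: Composite):
    # primal + differential: perturb the single `value` field
    return Fill(x.value + y.value, x.size)

x = Fill(1.0, (2, 2))
y = Composite(2.0)
print(add_as_differentials(x, y))  # Composite(value=6.0)
print(add_as_primal(x, y))         # Fill(value=3.0, size=(2, 2))
```

Same inputs, two different and individually sensible answers; the type of `x` alone does not tell you which one applies.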
I think there are a couple of solutions here:
- Distinguish between `+` and something similar to `Zygote.accum`. `+` would define primal + differential stuff, and `accum` differential + differential stuff.
- Insist that the type of an object tells us enough to know how to define `+` unambiguously. This would entail requiring that whenever we encounter a situation like the above, the primal + differential definition would apply, and it's the job of the rule-writer to ensure that no `frule` or `rrule` returns a `Fill` to represent a differential.

I could be convinced of either approach I think, but I instinctively prefer the latter, as the former feels like a patch -- better to disambiguate between fundamentally different objects than to make `+` mean different things in different contexts, imho. But a case could of course be made the other way.
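The first option can be sketched as follows (again with hypothetical Python stand-ins; `plus` and `accum` are illustrative names only): two separate entry points, so primal + differential and differential + differential never share a method and each case dispatches unambiguously.

```python
# Sketch of keeping `+` (here `plus`) for primal + differential and a
# separate `accum` for differential + differential accumulation.
from dataclasses import dataclass

@dataclass
class Fill:
    value: float
    size: tuple

@dataclass
class Composite:
    value: float

def plus(primal, dx):
    """primal + differential: move the primal along dx."""
    if isinstance(primal, Fill) and isinstance(dx, Composite):
        return Fill(primal.value + dx.value, primal.size)
    raise TypeError("no `plus` method for these types")

def accum(dx, dy):
    """differential + differential: accumulate adjoints."""
    if isinstance(dx, Composite) and isinstance(dy, Composite):
        return Composite(dx.value + dy.value)
    if isinstance(dx, Fill) and isinstance(dy, Composite):
        # an array differential meets a structural one: fold the
        # Fill's total adjoint into the structural tangent
        n = dx.size[0] * dx.size[1]
        return Composite(n * dx.value + dy.value)
    raise TypeError("no `accum` method for these types")

print(plus(Fill(1.0, (2, 2)), Composite(2.0)))   # Fill(value=3.0, size=(2, 2))
print(accum(Composite(1.5), Composite(2.0)))     # Composite(value=3.5)
```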
I'm really sceptical that a sensible design can be had in which we have both composite and abstractarray gradients.
Will's concern in FluxML/Zygote.jl#445 was that using a FillArray as the gradient was hacky (I'm taking the core issue to be that gradients are inconsistent with the equivalent dense array). But what I've tried to illustrate is that this is inevitable with any version of this, because if you follow through on `Composite` it ends up behaving the same way as `FillArray`; if you go that route you may as well just use `FillArray` and not pretend.
If, on the other hand, that tradeoff is unacceptable, you have to use dense (or perhaps one-hot) arrays, which again leads to throwing out `Composite`.
This will be true for any array type. If you mix composite and array gradients, you're going to get a somewhat arbitrary mix of different kinds of derivatives back out.
At some point you'll have to accumulate the differential for `x[1]` and `x[2]`, whose sum is clearly not also a 1-hot array.
Also, while I appreciate that this is meant to be more general, exactly the same O(N^2) behaviour affects regular arrays. It might be that the true solution is simply to fix this and not treat Fills as special at all (i.e. making array gradients mutable and updating them in place, which fixes the complexity).
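The complexity point can be sketched as follows (plain Python lists standing in for dense gradient arrays): accumulating K one-hot gradients immutably rebuilds the whole length-N array at each step (O(K·N)), while in-place accumulation is O(K + N). It also shows the running total stops being one-hot after the first couple of steps, as noted above.

```python
# Immutable vs in-place accumulation of one-hot gradients.
def accum_immutable(grads, n):
    total = [0.0] * n
    for (i, v) in grads:                  # each (index, adjoint) is one-hot
        # rebuilding the array copies all n entries every step: O(K*N)
        total = [t + (v if j == i else 0.0) for j, t in enumerate(total)]
    return total

def accum_inplace(grads, n):
    total = [0.0] * n
    for (i, v) in grads:
        total[i] += v                     # O(1) per one-hot gradient
    return total

grads = [(0, 1.0), (1, 2.0), (0, 3.0)]
# same answer either way, and the sum is not one-hot:
assert accum_immutable(grads, 4) == accum_inplace(grads, 4) == [4.0, 2.0, 0.0, 0.0]
```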
> Also, while I appreciate that this is meant to be more general, exactly the same O(N^2) behaviour affects regular arrays. It might be that the true solution is simply to fix this and not treat Fills as special at all (i.e. making array gradients mutable and updating them in place, which fixes the complexity).
The point here is more that in the dense case this is a necessary consequence of the immutability of differentials (assuming you've not got some fancy compiler optimisation going on). Conversely, this really shouldn't be the case for `Fill`s with immutable differentials.
> I'm really sceptical that a sensible design can be had in which we have both composite and abstractarray gradients.
I think I probably agree. That said, we can definitely have `Real`s and `Array`s representing the differentials / adjoints of `Real`s and `Array`s respectively.
> because if you follow through on `Composite` it ends up behaving the same way as `FillArray`; if you go that route you may as well just use `FillArray` and not pretend.
I'm sorry, but I'm still not seeing the inevitability of this. Assuming that you do indeed always represent the differential of a `Fill` as a `Composite`, the numbers that you get out when adding to a primal are quite different.
It depends on what the exact proposal is; the discussion on the Zygote issue indicated a mixing of array and composite adjoints (i.e. using composites for efficiency where possible but keeping array semantics broadly), so my comments are mainly aimed at attempts to preserve current semantics with better efficiency. I'm not certain but it seems like this issue is more about throwing out array semantics entirely, and defining the derivative only wrt the fill value, for which the earlier examples indeed don't apply.
This case is a bit different because gradients of FillArrays themselves are now unambiguous, but there are still ways in which their ambiguity can leak out. Contrived example (since you mentioned upstream rrules in the original issue):
```julia
@adjoint function *(a, b)
    y = a * b
    y = all(==(first(y)), y) ? Fill(mean(y), size(y)) : y
    y, c̄ -> (c̄ * b', a' * c̄)
end

gradient(x -> (W*x)[1], x)
```
If `W*x` produces a fill, it will get a `Composite` gradient back. It obviously has to be able to deal with this, which means it has to interpret the gradient as an array somehow, with an inherent ambiguity as to which elements of `y` are responsible for the gradient. You don't have to choose to interpret the `Composite` as a `FillArray` here – you could use a random array with the same sum – but it raises the same issues as if you did. In this case it's actually the gradients of two innocent bystanders – `W` and `x` – that end up being ambiguous.
Even if it were easy to buy `Fill` out of all the abstract array gradients, it's hard to see a way around this, since in generic code pretty much any adjoint can produce a `Fill` at any time, and those generic adjoints then all need to be updated to know how to convert `Composite{Fill}` to something like an array.
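A sketch of what such a conversion might look like (hypothetical Python stand-ins for the Julia types; the even split is exactly the arbitrary choice under discussion):

```python
# A generic adjoint that receives a structural Fill tangent but needs
# array semantics must pick *some* array realization of it; spreading
# the recorded total evenly across elements is one arbitrary-but-
# convenient choice (any split with the same sum would do).
from dataclasses import dataclass

@dataclass
class Fill:
    value: float
    size: tuple

@dataclass
class Composite:      # structural tangent of a Fill: d(value)
    value: float

def as_array(primal: Fill, dy: Composite):
    n = primal.size[0] * primal.size[1]
    # dy.value is the *sum* of the element adjoints; split it evenly
    return [dy.value / n] * n

y = Fill(2.5, (2, 2))
print(as_array(y, Composite(8.0)))  # [2.0, 2.0, 2.0, 2.0]
```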
Firstly, apologies for the slow response.
> It depends on what the exact proposal is; the discussion on the Zygote issue indicated a mixing of array and composite adjoints (i.e. using composites for efficiency where possible but keeping array semantics broadly), so my comments are mainly aimed at attempts to preserve current semantics with better efficiency. I'm not certain but it seems like this issue is more about throwing out array semantics entirely, and defining the derivative only wrt the fill value, for which the earlier examples indeed don't apply.
I think this is what this issue is more about now -- your comments on the ambiguities inherent in mixing composite and array types for differentials seem pretty compelling.
> If `W*x` produces a fill, it will get a `Composite` gradient back. It obviously has to be able to deal with this, which means it has to interpret the gradient as an array somehow, with an inherent ambiguity as to which elements of `y` are responsible for the gradient. You don't have to choose to interpret the `Composite` as a `FillArray` here – you could use a random array with the same sum – but it raises the same issues as if you did. In this case it's actually the gradients of two innocent bystanders – `W` and `x` – that end up being ambiguous.
I don't entirely follow -- specifically, it's unclear to me why it's valid to choose any array with the same sum if you receive a composite.
On the topic of using a `Fill` as the differential of a `Fill`: what if I have a `Fill` whose eltype is `ComplicatedType`, for which division-by-real isn't defined, and hence the mean isn't defined? In this case I can't see how you could represent the differential as a `Fill`.
Assuming you buy the `xs[i] ≡ mean(xs)` argument, then `xs[i]` is also equivalent to any weighted mean of `xs` (where the weights sum to one). With random weights you get a random gradient with the same sum; and we can reduce all `Fill` operations to `xs[i]`.
We should be clear, though, that the `mean(xs)` example is more of a tool for intuition in a particular case than something fundamental. The basic issue is that if we have only a `Composite` gradient and we need an element-wise gradient, the `Composite` only tells us the sum of the adjoints of all elements, and nothing about how individual elements contributed; which means we either have to throw an error here or pick a distribution somewhat arbitrarily (and we might motivate this by analogy to a specific operation, or by what's convenient performance-wise).
So the fact that `sum(xs)/length(xs)` is a weird operation for, say, `Date`s is not that important; we could probably contrive something that produces the equivalent gradient, but even if not, the idea that the `Composite` adjoint represents the sum of the adjoints of the array elements (whether represented as milliseconds or composites) is still well defined, and we still face the same issues when it comes to array operations.
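A quick numerical illustration of the point, sketched in Python: the structural tangent records only the total, so any weights summing to one yield an element-wise gradient consistent with it, and nothing distinguishes between them.

```python
# The Composite records only the sum of the element adjoints; any
# weights summing to one produce an element-wise realization of it.
import random

def elementwise_from_composite(total, weights):
    assert abs(sum(weights) - 1.0) < 1e-12
    return [total * w for w in weights]

total = 6.0                                   # what the Composite knows
uniform = elementwise_from_composite(total, [0.25] * 4)

w = [random.random() for _ in range(4)]
w = [wi / sum(w) for wi in w]                 # random weights, sum to 1
randomized = elementwise_from_composite(total, w)

# different element-wise gradients, identical total
assert abs(sum(uniform) - total) < 1e-9
assert abs(sum(randomized) - total) < 1e-9
```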
> Assuming you buy the `xs[i] ≡ mean(xs)` argument, then `xs[i]` is also equivalent to any weighted mean of `xs` (where the weights sum to one). With random weights you get a random gradient with the same sum; and we can reduce all `Fill` operations to `xs[i]`.
Okay, I think I follow.
I think my view on this is that we should never need an element-wise gradient for a `Fill`. Certainly, it means that you can't just expect an `rrule` defined for `AbstractArray`s to necessarily work.
Taking your example, the rule-writer has the information required to know unambiguously what the gradient should be, by virtue of the fact that `mean` was called internally inside `*` -- all of the gradients should be the same in the `Fill` case. If you had not written a rule for this, the AD would have been able to figure that out. I know I've taken your example quite literally, but it's unclear to me that this demonstrates ambiguity leakage.
The core point is really that switching to `Composite` gradients changes the gradients we get for `W` and `x`, in some cases, compared to the current situation; this is what (it seems) we want to avoid. OK, sure, you can define away the ambiguity by looking at the implementation of the primal -- but if we're redefining what counts as a correct gradient, we may as well make it convenient. Does the adjoint author have to keep an eye on the package providing the primal, and change the adjoint if `mean(y)` gets changed to `y[1]`?
More broadly, I'm not imagining that people will actually write code like this, but `Fill`s can be emitted from generic code, perhaps as a result of interactions between multiple packages. If an innocuous-looking adjoint happens to produce a `Fill`, the gradient will error out. To fix that, we have to effectively audit the implementation (potentially a lot of code written by other people) in order to define the gradient of the `Fill`, and then have dispatch in the adjoint to handle that case (and later make sure it stays in sync with the package). At the end of this we are still changing gradients from their dense array equivalents, and in a way that doesn't have any semantic guarantees as package implementations change.