Comments (8)
Hmm, it is a pullback for an integrator, which I implement by calling the pullback of the integrand and integrating the resulting tangent using the integrator. The integrator checks convergence by calling `norm` on the tangent. So apart from this convergence check, the pullback is linear in its inputs.
In other words, when summing an infinite series of tangents, it’s good to have a way to tell when the series has converged, and I guess for that you need a notion of distance. For tangents without any `NoTangent` entries hidden inside, this distance is induced by the current implementation of `norm`. It might make sense to define that distance also between tangents that have a `NoTangent` entry somewhere inside them.
Leaving everything else the same, defining `norm(::NoTangent) = 0` would give the correct distance for the application above, but I’m not sure if it’s a good idea in general.
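To make the idea concrete, here is a minimal sketch of a norm that treats "no tangent" entries as zero. It is only an illustration under stated assumptions: `NoTangentStandIn` and `tangent_norm` are hypothetical names introduced here (not part of ChainRulesCore), and structural tangents are modelled as plain tuples and named tuples so that the snippet runs without any packages beyond the standard library.

```julia
using LinearAlgebra

# Hypothetical stand-in for ChainRulesCore's NoTangent, so the sketch is self-contained.
struct NoTangentStandIn end

# Hypothetical helper: "no tangent" contributes nothing to the distance.
tangent_norm(::NoTangentStandIn) = 0.0
# Numbers and numeric arrays use the ordinary norm.
tangent_norm(x::Union{Number,AbstractArray{<:Number}}) = float(norm(x))
# Nested tangents (modelled as tuples / named tuples) combine their field norms as a 2-norm.
tangent_norm(t::Union{Tuple,NamedTuple}) =
    sqrt(sum(abs2, map(tangent_norm, Tuple(t)); init=0.0))

# A tangent with a "no tangent" entry somewhere inside it.
structural = (b = 3.0, a = NoTangentStandIn(), v = [0.0, 4.0])
result = tangent_norm(structural)   # sqrt(3^2 + 0^2 + 4^2)
```

With this definition the convergence check sees the same distance it would see if the `NoTangent` field were simply absent, which is the behaviour the integrator needs.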
I think `norm(::NoTangent) = NoTangent()` as you proposed is also sensible, as it would satisfy `norm(NoTangent())^2 === NoTangent()*NoTangent() === dot(NoTangent(), NoTangent())`. Although I wonder if `dot(NoTangent(), NoTangent())` returning `NoTangent()`, as it does currently, is the best definition, because the result of the scalar product should not be a member of the tangent space.
from chainrulescore.jl.
I believe this should just return `NoTangent()` because `norm` is a linear operator.
from chainrulescore.jl.
`norm` is not a linear map, e.g. `iszero(norm(x) + norm(-x))` only when `iszero(x)`.
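A quick way to see this is to compare `norm` with a map that really is linear, such as `sum`. The snippet below is just an illustrative check; the vector `x` is an arbitrary example value.

```julia
using LinearAlgebra

x = [3.0, 4.0]

# For any linear map L, L(x) + L(-x) == L(x - x) == 0.
linear_sum = sum(x) + sum(-x)

# `norm` is symmetric instead: norm(-x) == norm(x), so the sum is 2 * norm(x).
norm_sum = norm(x) + norm(-x)

println(iszero(linear_sum))  # true
println(iszero(norm_sum))    # false
```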
from chainrulescore.jl.
Fair point.
On that basis, why are you calling `norm` on a tangent?
Since it is not a linear map, that suggests your pullback (or pushforward) is incorrect, because pullbacks (and pushforwards) are always linear maps on their inputs.
from chainrulescore.jl.
@lukas-weber are you able to share your pullback code with us?
from chainrulescore.jl.
Sure, here is a minimal example using QuadGK as the integrator.
```julia
using ChainRulesCore
using Zygote
using QuadGK
using LinearAlgebra

# Wrapper that gives a tangent the arithmetic and `norm` the integrator needs.
struct TangentWrapper{T}
    tangent::T
end

Base.:(+)(a::TangentWrapper, b::TangentWrapper) = TangentWrapper(a.tangent + b.tangent)
# Binary `-` is written as `a + (-b)` because the binary method is missing on tangents.
Base.:(-)(a::TangentWrapper, b::TangentWrapper) = TangentWrapper(a.tangent + -b.tangent)
Base.:(*)(t::TangentWrapper, f::Number) = TangentWrapper(f * t.tangent)
# Scalar on the left, as in `exp(-x^2) * func(x)` below.
Base.:(*)(f::Number, t::TangentWrapper) = t * f
# The integrator calls `norm` for its convergence check.
LinearAlgebra.norm(t::TangentWrapper) = norm(t.tangent)

function integrate(func)
    return quadgk(x -> exp(-x^2) * func(x), -Inf, Inf)[1]
end

function ChainRulesCore.rrule(config::RuleConfig, ::typeof(integrate), func)
    y = integrate(func)
    project = ProjectTo(func)

    function integrate_pullback(Δy)
        # Integrand whose value at x is the tangent of `func` itself.
        function dfunc_integrand(x)
            _, inner_rrule = ChainRulesCore.rrule_via_ad(config, func, x)
            return TangentWrapper(inner_rrule(Δy)[1])
        end

        return NoTangent(), @thunk(project(integrate(dfunc_integrand).tangent))
    end

    return y, integrate_pullback
end

function test()
    a = 10
    b = 1.0

    # Captures the integer `a`, so the closure's tangent contains a NoTangent entry.
    has_no_tangent(x) = sum(fill(x, a))
    example_func(x) = cos(b * x) + has_no_tangent(x)

    @show integrate(example_func)
    @show gradient(f -> integrate(f), example_func)
end
```
The `TangentWrapper` type was introduced to provide the binary `-` method, which was also missing.
As such, the code works as long as the AD-provided tangent of `example_func` does not contain any `NoTangent`s down its tree. The `sum(fill(x,a))` closure is a contrived construct designed to produce such an entry. In practice it happened to me with a more complicated integrand operating on structured types.
(This was inspired by a similar rrule that was already implemented in Integrals.jl, but was not fit for my use case.)
from chainrulescore.jl.
I think this makes sense; `norm` is often used in exactly this way.
I think we can add this overload, at least for `AbstractZero` subtypes and `NotImplemented`.
Feel encouraged to make a PR.
from chainrulescore.jl.
Okay, I’ll take a stab at it.
from chainrulescore.jl.