Comments (8)

lukas-weber commented on August 16, 2024

Hmm, it is a pullback for an integrator, which I implement by calling the pullback of the integrand and integrating the resulting tangent with the integrator. The integrator checks convergence by calling norm on the tangent. So apart from this convergence check, the pullback is linear in its inputs.

In other words, when summing an infinite series of tangents, it’s good to have a way to tell when the series has converged, and I guess for that you need a notion of distance. For tangents without any NoTangent entries hidden inside, this distance is induced by the current implementation of norm. It might make sense to define that distance also between tangents that have a NoTangent entry somewhere inside them.

Leaving everything else the same, defining norm(::NoTangent) = 0 would give the correct distance for the application above, but I’m not sure if it’s a good idea in general.

I think norm(::NoTangent) = NoTangent() as you proposed is also sensible as it would satisfy norm(NoTangent())^2 === NoTangent()*NoTangent() === dot(NoTangent(), NoTangent()). Although I wonder if dot(NoTangent(), NoTangent()) returning NoTangent() as it does currently is the best definition, because the result of the scalar product should not be a member of the tangent space.
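As a rough illustration of the convergence check described above, here is a minimal sketch that does not depend on ChainRulesCore. The names `FakeNoTangent`, `tangent_norm`, and `sum_until_converged` are hypothetical stand-ins invented for this example, and the choice `norm(::FakeNoTangent) = 0.0` is just the first of the two options discussed, not a settled definition.

```julia
using LinearAlgebra

# Hypothetical stand-in for ChainRulesCore.NoTangent, so the sketch has no dependencies.
struct FakeNoTangent end

# Assumption: a structural "no tangent" entry contributes nothing to the distance.
LinearAlgebra.norm(::FakeNoTangent) = 0.0

# Norm of a structured tangent, represented here as a NamedTuple of fields.
tangent_norm(t::NamedTuple) = sqrt(sum(abs2(norm(v)) for v in values(t)))

# Add up terms until the norm of the latest term drops below `tol`.
# Only the numeric field `b` is accumulated; `a` carries no tangent.
function sum_until_converged(term; tol=1e-10, maxiter=1000)
    total = 0.0
    for k in 0:maxiter
        t = term(k)
        total += t.b
        tangent_norm(t) < tol && return total, k
    end
    error("series did not converge within maxiter terms")
end

# Each term has one field without a tangent and one geometric-series entry.
term(k) = (a = FakeNoTangent(), b = 0.5^k)
total, last_k = sum_until_converged(term)
```

Without the `norm` method for the stand-in type, `tangent_norm` would throw a `MethodError` on the first term, which mirrors the situation described in this comment.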

from chainrulescore.jl.

oxinabox commented on August 16, 2024

I believe this should just return NoTangent() because norm is a linear operator.

sethaxen commented on August 16, 2024

norm is not a linear map: for example, iszero(norm(x) + norm(-x)) holds only when iszero(x).
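A quick check of this point, as a minimal sketch using only the LinearAlgebra standard library (the vectors `x` and `y` are arbitrary example values):

```julia
using LinearAlgebra

x = [3.0, 4.0]
y = [1.0, 0.0]

# A linear map f would satisfy f(-x) == -f(x) and f(x + y) == f(x) + f(y).
neg_is_same = norm(-x) == norm(x)                 # norm ignores the sign
strictly_subadditive = norm(x + y) < norm(x) + norm(y)

@show norm(x) norm(-x)
@show norm(x + y) norm(x) + norm(y)
```

Both properties fail for norm: negation leaves the value unchanged, and the sum obeys only the triangle inequality rather than additivity.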

oxinabox commented on August 16, 2024

Fair point.
On that basis, why are you calling norm on a tangent?
Pullbacks (and pushforwards) are always linear maps of their inputs.
Since norm is not a linear map, calling it suggests your pullback (or pushforward) is incorrect.
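To make the linearity requirement concrete, here is a minimal sketch with a hand-written pullback for `sin`, using no AD library. The function name `sin_with_pullback` and the test values are made up for this illustration.

```julia
# Hand-written pullback for f(x) = sin(x) at a fixed point x.
function sin_with_pullback(x)
    y = sin(x)
    pullback(Δy) = cos(x) * Δy   # linear in the cotangent Δy
    return y, pullback
end

_, pb = sin_with_pullback(0.3)
Δ1, Δ2, c = 1.5, -0.7, 2.0

additive = pb(Δ1 + Δ2) ≈ pb(Δ1) + pb(Δ2)      # pb(Δ1 + Δ2) == pb(Δ1) + pb(Δ2)
homogeneous = pb(c * Δ1) ≈ c * pb(Δ1)         # pb(c * Δ) == c * pb(Δ)
@show additive homogeneous
```

A pullback whose returned value depended on `norm(Δy)` would break both checks, which is the concern raised here.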

sethaxen commented on August 16, 2024

@lukas-weber are you able to share your pullback code with us?

lukas-weber commented on August 16, 2024

Sure, here is a minimal example using QuadGK as the integrator.

using ChainRulesCore
using Zygote
using QuadGK
using LinearAlgebra

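# Thin wrapper that gives tangents the arithmetic and norm the integrator needs.
struct TangentWrapper{T}
    tangent::T
end

Base.:(+)(a::TangentWrapper, b::TangentWrapper) = TangentWrapper(a.tangent + b.tangent)
# Binary - is not defined for tangents, so subtract by adding the negation.
Base.:(-)(a::TangentWrapper, b::TangentWrapper) = TangentWrapper(a.tangent + -b.tangent)
Base.:(*)(t::TangentWrapper, f::Number) = TangentWrapper(f * t.tangent)
# The number can also be on the left, as in exp(-x^2) * func(x) inside integrate.
Base.:(*)(f::Number, t::TangentWrapper) = t * f
# The integrator calls norm for its convergence check.
LinearAlgebra.norm(t::TangentWrapper) = norm(t.tangent)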

function integrate(func)
    return quadgk(x->exp(-x^2) * func(x), -Inf, Inf)[1]
end

function ChainRulesCore.rrule(config::RuleConfig, ::typeof(integrate), func)
    y = integrate(func)
    project = ProjectTo(func)

    function integrate_pullback(Δy)
        function dfunc_integrand(x)
            _, inner_rrule = ChainRulesCore.rrule_via_ad(config, func, x)
            return TangentWrapper(inner_rrule(Δy)[1])
        end

        return NoTangent(), @thunk(project(integrate(dfunc_integrand).tangent))
    end

    return y, integrate_pullback
end


function test()
    a = 10
    b = 1.0
    has_no_tangent(x) = sum(fill(x, a))

    example_func(x) = cos(b * x) + has_no_tangent(x)

    @show integrate(example_func)
    @show gradient(f->integrate(f), example_func)
end

The TangentWrapper type was introduced to provide the binary - method which was also missing.

As such, the code works as long as the AD-provided tangent of example_func does not contain any NoTangents further down its tree. The sum(fill(x, a)) closure is a contrived way to trigger that case. In practice it happened to me with a more complicated integrand operating on structured types.

(This was inspired by a similar rrule that was already implemented in Integrals.jl, but that one did not fit my use case.)

oxinabox commented on August 16, 2024

I think this makes sense; norm is often used in exactly this way.
I think we can add this overload,
at least for AbstractZero subtypes and NotImplemented.
Feel encouraged to make a PR.

lukas-weber commented on August 16, 2024

Okay, I’ll take a stab at it.
