Comments (4)

devmotion commented on July 17, 2024

AFAICT the pullback should return a tuple of length N + 1. Alternatively with callable structs:

using ChainRulesCore

struct NonDiffPullback{N} end
(pb::NonDiffPullback{N})(::Vararg{Any,N}) where {N} = ntuple(Returns(NoTangent()), Val(N + 1))

from chainrulescore.jl.

devmotion commented on July 17, 2024

We can use Returns, since the package depends on Compat >= 3.40 and support for Returns seems to have been added in Compat 3.35.

nsajko commented on July 17, 2024

I'm not good with macros so I probably won't be tackling this.

Actually, I think I got this. I think the issue can be fixed without relying on macros too much using make_pullback_for_non_differentiable:

using ChainRulesCore

function make_pullback_for_non_differentiable(::Val{N}) where {N}
    Vararg{Any,N}  # throw early for invalid `N`, must be nonnegative `Int`
    function pullback_for_non_differentiable(::Vararg{Any,N})
        f = _ -> NoTangent()
        ntuple(f, Val(N))
    end
end

using Test

@testset "`make_pullback_for_non_differentiable`" begin
    f = make_pullback_for_non_differentiable
    @testset "throws on invalid input" begin
        @test_throws Exception f(Val(0.0))
        @test_throws Exception f(Val(-1))
    end
    @testset "identical objects" begin
        for i in 0:5
            v = Val(i)
            @test f(v) === f(v)
        end
    end
    @testset "dispatch" begin
        pullback = f(Val(2))
        @test_throws MethodError pullback()
        @test_throws MethodError pullback(1)
        @test (NoTangent(), NoTangent()) === pullback(1, 2)
        @test_throws MethodError pullback(1, 2, 3)
    end
end

nsajko commented on July 17, 2024

Thanks, you're right, there's an off-by-one error. But note we can't use Returns until support for Julia v1.6 is dropped.
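
For illustration, here is a minimal sketch of how the closure-based helper might look with the off-by-one fixed and without Returns. It is an assumption about the shape of the fix, not the code that ended up in the package: the helper name is reused from the earlier comment, the anonymous function stands in for Returns(NoTangent()), and the tuple length follows the N + 1 suggestion made in this thread.

```julia
using ChainRulesCore

# Hypothetical variant of `make_pullback_for_non_differentiable`:
# the pullback accepts exactly N arguments and returns N + 1 tangents.
function make_pullback_for_non_differentiable(::Val{N}) where {N}
    Vararg{Any,N}  # as in the earlier comment: throw early for invalid `N`
    function pullback_for_non_differentiable(::Vararg{Any,N})
        # An anonymous function replaces `Returns(NoTangent())`,
        # so this does not depend on `Returns` being available.
        ntuple(_ -> NoTangent(), Val(N + 1))
    end
end

pullback = make_pullback_for_non_differentiable(Val(2))
pullback(1, 2)  # a 3-tuple of `NoTangent()`
```

The "dispatch" test from the earlier comment would then need to expect three NoTangent() values from pullback(1, 2) instead of two.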
