Comments (4)
AFAICT the pullback should return a tuple of length `N + 1`. Alternatively, with callable structs:
```julia
using ChainRulesCore

struct NonDiffPullback{N} end

# Accepts exactly `N` arguments and returns a tuple of `N + 1` `NoTangent()`s
(pb::NonDiffPullback{N})(::Vararg{Any,N}) where {N} = ntuple(Returns(NoTangent()), Val(N + 1))
```
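For context on the `N + 1` count: in ChainRules, the pullback returned by an `rrule` gives one tangent for the function itself followed by one tangent per primal argument. A minimal hand-written sketch for a hypothetical two-argument function `count_matches` (an illustrative name, not part of ChainRulesCore), equivalent in spirit to what `@non_differentiable count_matches(a, b)` generates:

```julia
using ChainRulesCore

# Hypothetical function whose integer output has no meaningful derivative
count_matches(a, b) = count(==(b), a)

function ChainRulesCore.rrule(::typeof(count_matches), a, b)
    y = count_matches(a, b)
    # The pullback takes the output cotangent `ȳ` and returns N + 1 = 3 tangents:
    # one for `count_matches` itself, then one each for `a` and `b`
    count_matches_pullback(ȳ) = (NoTangent(), NoTangent(), NoTangent())
    return y, count_matches_pullback
end

y, pb = rrule(count_matches, [1, 2, 2], 2)
pb(1.0)  # (NoTangent(), NoTangent(), NoTangent())
```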
from chainrulescore.jl.
We can use `Returns`, since the package depends on Compat >= 3.40 and it seems support for `Returns` was added in Compat version 3.35.
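`Returns(x)` builds a callable that ignores all of its arguments and always returns `x`, so `ntuple(Returns(NoTangent()), Val(N + 1))` builds the same tuple as `ntuple(_ -> NoTangent(), Val(N + 1))`. It is only in Base from Julia 1.7 on, so on Julia 1.6 it has to come from Compat. A small Base-only sketch, using `nothing` in place of `NoTangent()` so it runs without ChainRulesCore (the variable names are illustrative):

```julia
# `Returns` is in Base from Julia 1.7; on older versions, take it from Compat.jl
if VERSION < v"1.7"
    using Compat: Returns
end

returns_nothing = Returns(nothing)  # ignores its arguments, always returns `nothing`
anon_nothing = _ -> nothing         # closure-based equivalent that works on any Julia version

returns_nothing(1, 2, 3)            # nothing
ntuple(returns_nothing, Val(3))     # (nothing, nothing, nothing)
```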
I'm not good with macros, so I probably won't be tackling this.
Actually, I think I got this. I think the issue can be fixed without relying too much on macros, using `make_pullback_for_non_differentiable`:
```julia
using ChainRulesCore

function make_pullback_for_non_differentiable(::Val{N}) where {N}
    Vararg{Any,N}  # throw early for invalid `N`, must be nonnegative `Int`
    function pullback_for_non_differentiable(::Vararg{Any,N})
        f = _ -> NoTangent()
        ntuple(f, Val(N))
    end
end
```

```julia
using Test

@testset "`make_pullback_for_non_differentiable`" begin
    f = make_pullback_for_non_differentiable
    @testset "throws on invalid input" begin
        @test_throws Exception f(Val(0.0))
        @test_throws Exception f(Val(-1))
    end
    @testset "identical objects" begin
        for i ∈ 0:5
            v = Val(i)
            @test f(v) === f(v)
        end
    end
    @testset "dispatch" begin
        pullback = f(Val(2))
        @test_throws MethodError pullback()
        @test_throws MethodError pullback(1)
        @test (NoTangent(), NoTangent()) === pullback(1, 2)
        @test_throws MethodError pullback(1, 2, 3)
    end
end
```
Thanks, you're right: there's an off-by-one error. But note that we can't use `Returns` until support for Julia v1.6 is dropped.