Comments (7)
This is fixed by #160. In general, it is a bad idea to just differentiate through find_alpha;
it is much more efficient to implement the derivatives explicitly, as done in that PR.
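A minimal sketch of what "implementing the derivatives explicitly" can mean, in plain Julia with no AD package. The idea is to differentiate the root condition instead of the solver's iterations. It assumes the scalar equation solved by find_alpha has the form α + c·tanh(α + b) = y, which follows from a planar layer f(z) = z + û·tanh(wᵀz + b) with y = wᵀf(z) and c = wᵀû. The exact parameterization in Bijectors may differ. `alpha_root` and `alpha_with_grad` are hypothetical names, and this is not the code from #160.

```julia
# Hypothetical helper: bisection solve of α + c*tanh(α + b) = y.
# Assumes c ≥ -1, so the left-hand side is monotone in α.
function alpha_root(y, c, b)
    # |tanh| ≤ 1, so the root lies in [y - |c|, y + |c|].
    lo, hi = y - abs(c), y + abs(c)
    for _ in 1:200
        m = (lo + hi) / 2
        if m + c * tanh(m + b) - y > 0
            hi = m
        else
            lo = m
        end
    end
    return (lo + hi) / 2
end

# Implicit function theorem: with F(α, y, c, b) = α + c*tanh(α + b) - y = 0,
#   dα/dθ = -(∂F/∂θ) / (∂F/∂α),  where ∂F/∂α = 1 + c*sech²(α + b).
# Only this closed-form gradient is needed; the solver is never differentiated.
function alpha_with_grad(y, c, b)
    α = alpha_root(y, c, b)
    t = tanh(α + b)
    s2 = 1 - t^2                 # sech²(α + b)
    dFdα = 1 + c * s2
    grad = (y = 1 / dFdα,        # ∂F/∂y = -1
            c = -t / dFdα,       # ∂F/∂c = tanh(α + b)
            b = -c * s2 / dFdα)  # ∂F/∂b = c*sech²(α + b)
    return α, grad
end
```

In an AD setting, these partials would typically be attached to the solver through a custom adjoint (for example ZygoteRules' `@adjoint` or a ChainRulesCore `rrule`), so Zygote never has to trace the try/catch inside the root finder. A central finite difference on `alpha_root` is a cheap way to check the formulas.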
from bijectors.jl.
I can't run your example, and it's a bit difficult to comment without knowing the exact error message. It is also a bit unclear to me which packages you used here. Did you load DistributionsAD?
@devmotion I've updated the example with the full code.
Yes, I've loaded DistributionsAD
using Turing, Bijectors, Flux, ProgressMeter, DistributionsAD
Thanks, now I can run the code. The error is caused by AD issues with the Roots package. The inverse of planar layers is computed using a root-finding algorithm from the Roots package (see Bijectors.jl/src/bijectors/planar_layer.jl, line 122 in a854144).
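To make the root-finding step concrete: for a planar layer f(z) = z + û·tanh(wᵀz + b), taking the inner product with w reduces inversion to a scalar equation in α = wᵀz, namely α + (wᵀû)·tanh(α + b) = wᵀy, after which z = y - û·tanh(α + b). The sketch below assumes wᵀû ≥ -1 (the invertibility condition that û is constructed to satisfy), so the scalar equation is monotone, and it uses plain bisection in place of Roots.jl. `solve_alpha`, `planar` and `planar_inverse` are hypothetical names, not the Bijectors API.

```julia
using LinearAlgebra  # stdlib, for dot

# Hypothetical helper (not Bijectors' implementation): solve
#   α + c * tanh(α + b) = y
# for α by bisection. Assumes c ≥ -1, so the left-hand side is monotone in α.
function solve_alpha(y::Float64, c::Float64, b::Float64; tol::Float64 = 1e-12)
    # Since |tanh| ≤ 1, the root lies in [y - |c|, y + |c|].
    lo, hi = y - abs(c), y + abs(c)
    g(α) = α + c * tanh(α + b) - y
    for _ in 1:200
        m = (lo + hi) / 2
        if g(m) > 0
            hi = m
        else
            lo = m
        end
        hi - lo < tol && break
    end
    return (lo + hi) / 2
end

# Forward planar layer: f(z) = z + u_hat * tanh(wᵀz + b).
planar(z, w, u_hat, b) = z .+ u_hat .* tanh(dot(w, z) + b)

# Inverse: solve for α = wᵀz, then subtract the nonlinearity.
function planar_inverse(y, w, u_hat, b)
    α = solve_alpha(dot(w, y), dot(w, u_hat), b)
    return y .- u_hat .* tanh(α + b)
end
```

Bisection is chosen here only because the bracket [y - |c|, y + |c|] is known in closed form and the method cannot fail on it; it says nothing about which algorithm Roots.jl actually uses.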
BTW, the different AD backends (Tracker, ForwardDiff, ReverseDiff, and Zygote) all have different advantages and disadvantages, and the optimal choice usually depends on the problem and your implementation. Tracker is not discontinued: similar to ForwardDiff, it is solid and maintained, but no major new features are planned.
Thanks! I will open an issue on Zygote to see if there are any plans to support it, as it has some other advantages.
Just to update: this is the error with the latest package versions.
Compiling Tuple{typeof(Bijectors.find_alpha),Float64,Float64,Float64}: try/catch is not supported.
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] instrument(::IRTools.Inner.IR) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/reverse.jl:89
[3] #Primal#20 at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/reverse.jl:170 [inlined]
[4] Zygote.Adjoint(::IRTools.Inner.IR; varargs::Nothing, normalise::Bool) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/reverse.jl:283
[5] _lookup_grad(::Type{T} where T) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/emit.jl:101
[6] #s2937#1244 at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:37 [inlined]
[7] #s2937#1244(::Any, ::Any, ::Any) at ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:527
[9] #1079 at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/lib/broadcast.jl:150 [inlined]
[10] #3844#back at /Users/luccazenobio/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[11] (::Zygote.var"#150#151"{Zygote.var"#3844#back#1082"{Zygote.var"#1079#1081"{typeof(∂(find_alpha))}},Tuple{NTuple{4,Nothing},Tuple{Nothing}}})(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/lib/lib.jl:191
[12] #1693#back at /Users/luccazenobio/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[13] broadcasted at ./broadcast.jl:1263 [inlined]
[14] Inverse at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/bijectors/planar_layer.jl:117 [inlined]
[15] (::typeof(∂(λ)))(::Array{Float64,1}) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[16] logabsdetjac at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/interface.jl:85 [inlined]
[17] (::typeof(∂(logabsdetjac)))(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[18] forward at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/interface.jl:98 [inlined]
[19] (::typeof(∂(forward)))(::NamedTuple{(:rv, :logabsdetjac),Tuple{Array{Float64,1},Float64}}) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[20] macro expansion at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/bijectors/composed.jl:0 [inlined]
[21] forward at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/bijectors/composed.jl:219 [inlined]
[22] (::typeof(∂(forward)))(::NamedTuple{(:rv, :logabsdetjac),Tuple{Array{Float64,1},Float64}}) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[23] _logpdf at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/transformed_distribution.jl:105 [inlined]
[24] (::typeof(∂(_logpdf)))(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[25] #323 at /Users/luccazenobio/.julia/packages/DistributionsAD/HvoZ3/src/zygote.jl:85 [inlined]
[26] (::typeof(∂(λ)))(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[27] #502 at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/lib/array.jl:187 [inlined]
[28] #3 at ./generator.jl:36 [inlined]
[29] iterate at ./generator.jl:47 [inlined]
[30] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof(∂(λ)),1},Array{Float64,1}}},Base.var"#3#4"{Zygote.var"#502#506"}}) at ./array.jl:686
[31] map at ./abstractarray.jl:2248 [inlined]
[32] (::Zygote.var"#501#505"{Array{typeof(∂(λ)),1}})(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/lib/array.jl:187
[33] #2537#back at /Users/luccazenobio/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[34] #322 at /Users/luccazenobio/.julia/packages/DistributionsAD/HvoZ3/src/zygote.jl:85 [inlined]
[35] (::typeof(∂(#322)))(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[36] #41 at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:40 [inlined]
[37] #532#back at /Users/luccazenobio/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[38] loss at ./In[13]:12 [inlined]
[39] #31 at ./In[13]:16 [inlined]
[40] (::typeof(∂(λ)))(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[41] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:172
[42] macro expansion at ./In[13]:17 [inlined]
[43] macro expansion at /Users/luccazenobio/.julia/packages/ProgressMeter/GhAId/src/ProgressMeter.jl:762 [inlined]
[44] nf_train(::TransformedDistribution{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},Composed{Tuple{PlanarLayer{Array{Float64,1},Array{Float64,1}},LeakyReLU{Float64,1},RadialLayer{Array{Float64,1},Array{Float64,1}},LeakyReLU{Float64,1},PlanarLayer{Array{Float64,1},Array{Float64,1}}},1},Multivariate}, ::Array{Float64,2}, ::ADAM, ::Zygote.Params, ::Int64) at ./In[13]:15
[45] top-level scope at In[13]:21
[46] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091