Comments (12)
Alternatively, one may want to use a cheap trick to avoid hitting Inf at all: use logit(z / (1 + eps(T))) instead of logit(z). This is an ordinary mathematical function, so it should play well with autodiff. Cheap but convenient.

    julia> logit(1.0/(1.0 + eps(Float64)))
    36.04365338911715
from bijectors.jl.
And changing y[K] = zero(T) to y[K] = 1 - sum(x) fixes the Jacobian problem.
The subtraction trick doesn't seem to solve all problems. I will check it again tomorrow; it's way past bedtime here :)
I think the cause, or at least one cause, of the problem is this:

    julia> using Bijectors
    julia> link(Dirichlet([1., 1., 1.]), [1.0, 0.0, 0.0])
    3-element Array{Float64,1}:
     Inf
     NaN
     0.0

The current implementation of link for Dirichlet is only stable if the running sum does not reach 1 before the very last element of x; otherwise Infs, and soon after NaNs, start showing up.
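To see where the Inf and the NaN come from, the first two steps of the loop can be traced by hand for x = [1.0, 0.0, 0.0]. This is a minimal standalone sketch, not the library code: `mylogit` is a hypothetical stand-in for StatsFuns.logit so that no packages are needed.

```julia
# Hypothetical stand-in for StatsFuns.logit, so the sketch needs no packages.
mylogit(p) = log(p / (1 - p))

x = [1.0, 0.0, 0.0]
K = length(x)

# k = 1: logit(1) is already infinite.
y1 = mylogit(x[1]) - log(1 / (K - 1))

# k = 2: the running sum has reached 1, so z is 0 / 0.
sum_tmp = x[1]
z = x[2] / (1 - sum_tmp)
y2 = mylogit(z) - log(1 / (K - 2))

println((y1, z, y2))
```

The Inf comes directly from logit(1), while the NaN comes from the 0/0 division in z rather than from logit itself.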
I am no expert, but arguably the most sensible value of logit(1) is Inf. Given the following implementation, in the limit as x[k] goes to 0, assuming sum(x[i] for i in 1:k-1) goes to 1, i.e. sum_tmp = 1 - x[k] goes to 1, z tends to 0/0. Using L'Hôpital's rule, we can see that z really tends to 1, so y[k] tends to Inf. Applying a similar analysis to all later ks, I think it makes sense to set every y[k] to Inf starting from the first k where sum_tmp == 1, assuming of course that the sum of x is 1. In a world where machine precision comes into play and we really want real numbers rather than Infs, we can use floatmax(T), which is the largest finite number representable by T.

    function link(d::SimplexDistribution, x::AbstractVector{T}) where T<:Real
        y, K = similar(x), length(x)
        sum_tmp = zero(T)
        z = x[1]
        y[1] = StatsFuns.logit(z) - log(one(T) / (K - 1))
        @inbounds for k in 2:K - 1
            sum_tmp += x[k - 1]
            z = x[k] / (one(T) - sum_tmp)
            y[k] = StatsFuns.logit(z) - log(one(T) / (K - k))
        end
        y[K] = zero(T)
        return y
    end

The only problem is that I think the if statement may mess up with the autodiff.
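The limit argument in this comment can be written out explicitly. This is a sketch under the same assumption the comment makes, namely that all remaining mass sits in x[k], so that the running sum equals 1 - x[k]. The symbols s_{k-1} (for sum_tmp) and z_k are notation introduced here.

```latex
s_{k-1} = \sum_{i=1}^{k-1} x_i = 1 - x_k
\quad\Longrightarrow\quad
z_k = \frac{x_k}{1 - s_{k-1}} = \frac{x_k}{x_k} = 1 \quad (x_k > 0),
\qquad
\lim_{x_k \to 0^+} z_k = 1,
\qquad
\lim_{z \to 1^-} \operatorname{logit}(z) = +\infty .
```

Since the offset log(1 / (K - k)) is finite, y[k] tends to Inf along this path. If the remaining mass is instead spread over later entries, the ratio can approach other values in [0, 1], so the conclusion depends on the stated assumption.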
Could you provide a toy implementation to make it more clear what you mean please @mohamed82008 ?
> The only problem is that I think the if statement may mess up with the autodiff.
We could just hand-code the reverse-mode sensitivity for this operation (which we should probably do anyway for performance reasons). That way we can handle the limits ourselves to ensure that sensible things happen.
> Could you provide a toy implementation to make it more clear what you mean please @mohamed82008 ?

    function link(d::SimplexDistribution, x::AbstractVector{T}) where T<:Real
        y, K = similar(x), length(x)
        sum_tmp = zero(T)
        z = x[1]
        y[1] = clamp(StatsFuns.logit(z) - log(one(T) / (K - 1)), -floatmax(T), floatmax(T))
        @inbounds for k in 2:K - 1
            sum_tmp += x[k - 1]
            z = x[k] / (one(T) - sum_tmp)
            isnan(z) && (z = one(T))
            y[k] = clamp(StatsFuns.logit(z) - log(one(T) / (K - k)), -floatmax(T), floatmax(T))
        end
        y[K] = zero(T)
        return y
    end

Ah I see. Yeah, clamp will zero the gradients (assuming a sane implementation of clamp) when the input is outside (-floatmax(T), floatmax(T)), so this probably isn't the best idea. The eps trick looks reasonable, but I don't think your proposal will handle the case where x[k] is zero. If we instead replace our logits with

    logit((z + eps(T)) / (one(T) + 2 * eps(T)))

we have something that works for both zeros and ones. We should also modify the corresponding inverse link function to ensure it's actually the inverse.
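As a rough illustration of the transform above and of what a matching inverse could look like, here is a minimal standalone sketch. The names `squash`, `unsquash`, `mylogit` and `mylogistic` are hypothetical and not part of Bijectors or StatsFuns; the logit and logistic functions are defined locally so the snippet runs without packages.

```julia
# Hypothetical helper names; logit/logistic are defined locally so that
# the sketch runs without StatsFuns.
mylogit(p) = log(p / (1 - p))
mylogistic(y) = 1 / (1 + exp(-y))

# Map [0, 1] into [eps/(1 + 2eps), (1 + eps)/(1 + 2eps)] before the logit.
squash(z::T) where {T<:AbstractFloat} = (z + eps(T)) / (one(T) + 2 * eps(T))

# Algebraic inverse of `squash`, as would be needed in the inverse link.
unsquash(w::T) where {T<:AbstractFloat} = w * (one(T) + 2 * eps(T)) - eps(T)

for z in (0.0, 0.5, 1.0)
    y = mylogit(squash(z))              # finite even at the endpoints
    z_back = unsquash(mylogistic(y))    # approximate round trip
    println((z, y, z_back))
end
```

The round trip is only approximate in floating point, so a real inverse link would probably still want to guard against values that land marginally outside [0, 1].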
This seems to work:

    function link(d::SimplexDistribution, x::AbstractVector{T}) where T<:Real
        y, K = similar(x), length(x)
        sum_tmp = zero(T)
        z = x[1]
        y[1] = StatsFuns.logit(z/(1 + eps(T))) - log(one(T) / (K - 1))
        @inbounds for k in 2:K - 1
            sum_tmp += x[k - 1]
            z = (x[k] + eps(T)) / (one(T) - sum_tmp + eps(T))
            @show z
            y[k] = StatsFuns.logit(z/(1 + eps(T))) - log(one(T) / (K - k))
        end
        y[K] = zero(T)
        return y
    end

What if x[1] == 0?
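A quick check of that case, as a standalone sketch: dividing by 1 + eps(T) only guards the upper end of the interval, so a zero input still produces -Inf. `mylogit` is a hypothetical stand-in for StatsFuns.logit.

```julia
# Hypothetical stand-in for StatsFuns.logit, so the sketch needs no packages.
mylogit(p) = log(p / (1 - p))

T = Float64
at_one  = mylogit(one(T) / (1 + eps(T)))   # finite: the upper end is guarded
at_zero = mylogit(zero(T) / (1 + eps(T)))  # -Inf: the lower end is not

println((at_one, at_zero))
```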
Oh I see what you mean.
Yeah, the reason I went for that particular transform is that it takes 0 to approximately eps(T) and 1 to approximately 1 - eps(T) in a continuous manner. So provided that we don't get any values of z outside [0, 1], we should be fine.