Comments (17)
I like the StaticBijector idea!
Great! So we go for the "Dimensionality" approach now, and then we can introduce StaticBijector later on? I think StaticBijector requires some more thought and work, so personally I don't want to rush that.
from bijectors.jl.
I agree with Tor on this one. Here is an excerpt from Slack:
No, encoding the size in the type is not good imo. It will bite us later. If the user passes x::StaticArray, we can specialize on the size of x, but there is no need for the bijector to be limited to a certain size imo.
I am not sure checking size-discrepancy between bijectors in Composed upon construction is the right thing either. For instance the bijector may be a linear map with the map matrix size being runtime info.
Or a scale which makes sense regardless of the dimension.
New possible solution:
# `N = 0, 1, 2, ...` represents the expected input/output (they're the same because bijection)
abstract type Bijector{N} end
# ...
# Can implement `Log` in the following way
struct Log{N} <: Bijector{N} end
Log() = Log{0}()
# can also specialize this further to catch dim-discrepancies at compile-time
(b::Log)(x) = @. log(x)
inv(b::Log{N}) where {N} = Exp{N}()
logabsdetjac(b::Log{0}, x::Real) = - log(x)
_logabsdetjac(b::Log{0}, x::Real) = - log(x)
_logabsdetjac(b::Log{1}, x::AbstractVector) = - sum(log.(x))
_logabsdetjac_batch(b::Log{0}, x::AbstractVector) = _logabsdetjac.(Ref(b), x)
_logabsdetjac_batch(b::Log, x::AbstractArray) = mapslices(z -> _logabsdetjac(b, z), x; dims = 1)
# Can make this `Bijector{N1}` and have user override `_logabsdetjac`
# and `_logabsdetjac_batch` instead of `logabsdetjac`. Can also
# provide default implementation of `_logabsdetjac_batch` using `mapslices`
@generated function logabsdetjac(
b::Log{N1},
x::AbstractArray{T2, N2}
) where {N1, T2, N2}
if N1 == N2
return :(_logabsdetjac(b, x))
elseif N1 + 1 == N2
return :(_logabsdetjac_batch(b, x))
else
return :(throw(MethodError(logabsdetjac, (b, x))))
end
end
This results in
julia> using Bijectors
[ Info: Recompiling stale cache file /home/tor/.julia/compiled/v1.0/Bijectors/39uFz.ji for Bijectors [76274a88-744f-5084-9051-94815aaf08c4]
julia> using Bijectors: Log
julia> b = Log{0}()
Log{0}()
julia> b(1.0) # ✓ Correctly treated as single input
0.0
julia> b([1.0, 2.0]) # ✓ Correctly treated as a batch of inputs
2-element Array{Float64,1}:
0.0
0.6931471805599453
julia> b = Log{1}() # expects dim 1 as output
Log{1}()
julia> b([1.0, 2.0]) # ✓ Correctly transformed as a single input
2-element Array{Float64,1}:
0.0
0.6931471805599453
julia> logabsdetjac(b, 1.0) # ✓ Correctly not implemented
ERROR: MethodError: no method matching logabsdetjac(::Log{1}, ::Float64)
Closest candidates are:
logabsdetjac(::Log{0}, ::Real) at /home/tor/.julia/dev/Bijectors/src/interface.jl:314
logabsdetjac(::Inversed{#s43,N} where N where #s43<:Bijector, ::Any) at /home/tor/.julia/dev/Bijectors/src/interface.jl:77
logabsdetjac(::ADBijector, ::Real) at /home/tor/.julia/dev/Bijectors/src/interface.jl:135
...
Stacktrace:
[1] top-level scope at none:0
julia> logabsdetjac(b, [1.0, 2.0]) # ✓ Correctly treated as a single input
-0.6931471805599453
julia> logabsdetjac(b, [1.0 2.0 3.0; 3.0 4.0 5.0]) # ✓ Correctly treated as a batch of inputs
1×3 Array{Float64,2}:
-1.09861 -2.07944 -2.70805
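Since inv(::Log{N}) above returns Exp{N}(), the Exp counterpart could be sketched analogously (hypothetical, mirroring the Log definitions; the Bijector{N} and Log{N} stubs are repeated so the snippet runs on its own):

```julia
# Stubs repeated from above so this sketch is self-contained.
abstract type Bijector{N} end
struct Log{N} <: Bijector{N} end

# Hypothetical `Exp{N}` mirroring `Log{N}`. The logabsdetjac of `exp` at `y`
# is just `y` (elementwise), since d/dy exp(y) = exp(y) and log(exp(y)) = y.
struct Exp{N} <: Bijector{N} end
Exp() = Exp{0}()
(b::Exp)(y) = @. exp(y)
Base.inv(b::Exp{N}) where {N} = Log{N}()
Base.inv(b::Log{N}) where {N} = Exp{N}()

logabsdetjac(b::Exp{0}, y::Real) = y
logabsdetjac(b::Exp{0}, y::AbstractVector) = copy(y)               # batch of scalars
logabsdetjac(b::Exp{1}, y::AbstractVector) = sum(y)                # single vector input
logabsdetjac(b::Exp{1}, y::AbstractMatrix) = vec(sum(y; dims = 1)) # batch of vectors
```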
We can then do this for the rest of the bijectors. Worth noting that many of the bijectors don't care about the dimension, e.g. SimplexBijector is then simply implemented as
struct SimplexBijector{T} <: Bijector{1} end
since it only makes sense as a Bijector treating a 1D vector as a single input and a 2D matrix as a batch of inputs.
All in all, I think it looks a bit nasty internally, but for the user it's going to be quite nice :)
EDIT: Another nice consequence is that we can catch dimensionality discrepancies in Composed at compile-time, i.e. it's not possible to compose bijectors expecting different input dimensionalities!
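A minimal sketch of what that compile-time check could look like (hypothetical compose helper and Composed type, not necessarily the actual Bijectors.jl implementation): the Vararg method only unifies when every component shares the same N, so mixed dimensionalities fail with a MethodError at construction time.

```julia
abstract type Bijector{N} end
struct Log{N} <: Bijector{N} end
struct Exp{N} <: Bijector{N} end

# Hypothetical `Composed`: all components must share the same dimensionality
# parameter `N`, so composing mismatched bijectors fails to dispatch.
struct Composed{N, T<:Tuple} <: Bijector{N}
    ts::T
end
compose(ts::Bijector{N}...) where {N} = Composed{N, typeof(ts)}(ts)
```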
EDIT 2: The use of @generated is fairly unnecessary in the Log case:
logabsdetjac(b::Log{0}, x::Real) = -log(x)
logabsdetjac(b::Log{0}, x::AbstractVector) = -log.(x)
logabsdetjac(b::Log{1}, x::AbstractVector) = - sum(log.(x))
logabsdetjac(b::Log{1}, x::AbstractMatrix) = - vec(sum(log.(x); dims = 1))
This replaces all the business of _logabsdetjac and _logabsdetjac_batch 👍
Also, in this particular case you could just use broadcasting for batches of Real, but you'd still have a problem with matrix-valued inputs, etc., where it would look similar but broadcasting wouldn't do it.
So I think there are two main approaches to solving this, and I'd love to hear what people think.
Size
Explicitly provide the size of the expected input for all bijectors you construct (you also need to do it for the ones which are only well-defined on a specific type of input, e.g. SimplexBijector).
abstract type Bijector{Size} end
# Example of construction:
# Expects input `x` with `size(x) = (3, )`
# OR `size(x) = (3, n)` where `n` is the number of inputs ("batch-computation")
Logit{(3, )}()
# Can also make it a bit nicer by adding constructors, e.g.:
Logit(size::Tuple) = Logit{size}()
Pros
- Catch size discrepancies for compositions of bijectors during composition (compile-time) rather than during evaluation (run-time)
- Can do checks at compile-time when using StaticArray (but whether you're catching size discrepancies when actually evaluating a Bijector on an input at compile-time or run-time doesn't really matter as far as I can tell)
- The upcoming Stacked bijector in #36 provides a way of stacking 0- and 1-dimensional bijectors to act on a single vector, e.g. apply Logit to x[1:3] and SimplexBijector to x[4:end] or something. With static sizes we'd be able to provide sane defaults for the ranges, e.g. stacking Logit{(N1, )} and SimplexBijector{(N2, )} will apply Logit to x[1:N1] and SimplexBijector to x[N1 + 1:end]. Can also do this for more than just two bijectors. At the moment, the user has to provide the ranges unless they're all 0-dimensional, in which case we'll just do elementwise application.
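To illustrate the "sane defaults" point, a hypothetical helper (not actual Bijectors.jl API) could derive the ranges from statically known sizes along these lines:

```julia
# Given component sizes (N1,), (N2,), ..., compute the default ranges
# x[1:N1], x[N1+1:N1+N2], ... that each stacked bijector would act on.
function default_ranges(sizes::Tuple...)
    lens = collect(first.(sizes))
    stops = cumsum(lens)
    starts = [1; stops[1:end-1] .+ 1]
    return [a:b for (a, b) in zip(starts, stops)]
end
```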
Cons
- It puts a lot of effort on the user. You need to do this explicitly for all bijectors.
- When working with non-static arrays you'll in many cases have to reconstruct the Bijector on each input to ensure its dimensionality is correct. This can be inefficient (and annoying) in certain cases.
- It can also get fairly ugly internally, which might also be a problem for users who want to implement their own Bijector.
Dimensionality
Instead of the size of the expected input, we can parameterize Bijector by the dimensionality of the expected input, i.e. length(size). This is basically like Univariate and Multivariate from Distributions.jl, but in full generality.
abstract type Bijector{N} end
# Example of construction:
Logit{1}() # expects vector input
Could alternatively wrap N in some type Dim, e.g. Dim{N}, or something like that to give N more semantic meaning.
In a lot of cases a bijector is only well-defined for a particular dimensionality, e.g. SimplexBijector is only well-defined on a 1D vector. Such cases will then subtype Bijector{specific_value_of_N}, e.g.
struct SimplexBijector <: Bijector{1} end
Pros
- Catching dimensionality (not size) discrepancies at composition rather than at evaluation
- From a user-perspective, there's not much change. Only certain bijectors require the user to provide the dimensionality of the expected input since most are only well-defined on, say, a 1D vector.
- Often easy to provide sane defaults, e.g. Log() = Log{0}().
- Provide default constructor for Stacked in the case where the bijectors to stack are all 0-dimensional
Cons
- For a small subset of bijectors, the user has to specify the dimensions.
Personally I prefer the "Dimensionality" approach 🤷‍♂️
Seems that for Stacked purposes the "Size" approach is better?
Yeah, but I'm also afraid that Size will cause issues with type-stability due to needing run-time information for construction (unless using StaticArray, which is worse than a standard Array in higher dimensions). So for Stacked it would indeed be better in the case where we construct the bijector at the top level, but as soon as we construct it inside a function call or whatever, we'd fall back to dynamic dispatch, right? This wouldn't be the case if we only depend on the dimensionality of the input array (which is known at compile time).
I think I really prefer not using Size now. Alternatively we can provide a StaticBijector or something later on? E.g. introduce AbstractBijector, let Bijector{N} <: AbstractBijector and StaticBijector{Size} <: AbstractBijector?
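The type-stability concern can be illustrated with a hypothetical toy type that carries a size in its type parameter; Base.return_types shows that construction from runtime size information is not concretely inferable, while a literal size is:

```julia
# Toy type carrying a size in its type parameter, mimicking the "Size" approach.
struct SizedBijector{Size} end

# Constructed from runtime info: the length of `x` is not known to the
# compiler, so the concrete return type cannot be inferred.
from_runtime(x::AbstractVector) = SizedBijector{(length(x),)}()

# Constructed from a literal size: fully inferable.
from_literal() = SizedBijector{(3,)}()
```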
I think I'm actually quite fond of using the "Dimensionality" approach now, and then we can introduce StaticBijector later as follows:
# Replaces `Bijector` as root-type
abstract type AbstractBijector end
# Redefine `Bijector{N}` as subtype of `AbstractBijector`
# => all existing code remains the same
abstract type Bijector{N} <: AbstractBijector end
# Introduce new type parameterized by `Size`
abstract type StaticBijector{Size} <: AbstractBijector end
# implement fixed-size implementations of different bijectors
...
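As a rough sketch of such a fixed-size implementation (hypothetical names; NTuple stands in for StaticArray to keep the snippet dependency-free, and Size here is simply the input length), one of those implementations could look like:

```julia
abstract type AbstractBijector end
abstract type StaticBijector{Size} <: AbstractBijector end

# Hypothetical fixed-size elementwise log, expecting exactly `N` inputs.
# Calling it with an input of any other length is a MethodError.
struct StaticLog{N} <: StaticBijector{N} end
(b::StaticLog{N})(x::NTuple{N, <:Real}) where {N} = map(log, x)
```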
StaticBijector also sounds good to me.
Yes it sounds like a good plan to me.
I think this was mistakenly closed by the merging of PR #36? It's still unsolved, but an implementation, and thus a fix for this issue, is almost done and can be found in PR #44
Just to make sure: everyone is happy with this approach?
If so, I'll make the final changes to PR #44 so we can get this merged. I have a lot of work on my personal GitHub which uses this branch as a base, and it's been working really well so far.
Should we close this issue now that #44 has been merged? Or keep it up for future discussion? I'm still not 100% convinced the approach we've taken is absolutely optimal (though at the moment it seems like the best route)
I'm pro closing this. We could always open a new issue at a later date if someone has a bright idea, but I'm in agreement that this solution feels pretty optimal for now.
We should also bump the version number at some point, since this is possibly a breaking change.
But there's a lot of work ready to go based on #44 being merged, so we should probably get that into master before making an actual release.
Closed by #44