torfjelde commented on July 2, 2024

I like the StaticBijector idea!

Great! So we go for the "Dimensionality" approach now, and then we can introduce StaticBijector later on? I think StaticBijector requires some more thought and work, so personally I don't want to rush that.

from bijectors.jl.

mohamed82008 commented on July 2, 2024

I agree with Tor on this one. Here is an excerpt from Slack:

No, encoding the size in the type is not good imo. It will bite us later. If the user passes x::StaticArray, we can specialize on the size of x, but there is no need for the bijector itself to be limited to a certain size.

I am not sure checking size-discrepancy between bijectors in Composed upon construction is the right thing either. For instance the bijector may be a linear map with the map matrix size being runtime info.

Or a scale which makes sense regardless of the dimension.

xukai92 commented on July 2, 2024

I like the StaticBijector idea!

torfjelde commented on July 2, 2024

New possible solution:

# `N = 0, 1, 2, ...` is the dimensionality of the expected input/output (the same for both, since it's a bijection)
abstract type Bijector{N} end

# ...

# Can implement `Log` in the following way
struct Log{N} <: Bijector{N} end
Log() = Log{0}()

# can also specialize this further to catch dim-discrepancies at compile-time
(b::Log)(x) = @. log(x)

inv(b::Log{N}) where {N} = Exp{N}()

logabsdetjac(b::Log{0}, x::Real) = - log(x)

_logabsdetjac(b::Log{0}, x::Real) = - log(x)
_logabsdetjac(b::Log{1}, x::AbstractVector) = - sum(log.(x))

# `Ref(b)` makes broadcasting treat `b` as a scalar rather than trying to iterate it
_logabsdetjac_batch(b::Log{0}, x::AbstractVector) = _logabsdetjac.(Ref(b), x)
_logabsdetjac_batch(b::Log, x::AbstractArray) = mapslices(z -> _logabsdetjac(b, z), x; dims = 1)

# Can make this `Bijector{N1}` and have user override `_logabsdetjac` 
# and `_logabsdetjac_batch` instead of `logabsdetjac`. Can also
# provide default implementation of `_logabsdetjac_batch` using `mapslices`
@generated function logabsdetjac(
    b::Log{N1},
    x::AbstractArray{T2, N2}
) where {N1, T2, N2}
    if N1 == N2
        return :(_logabsdetjac(b, x))
    elseif N1 + 1 == N2
        return :(_logabsdetjac_batch(b, x))
    else
        return :(throw(MethodError(logabsdetjac, (b, x))))
    end
end

This results in

julia> using Bijectors
[ Info: Recompiling stale cache file /home/tor/.julia/compiled/v1.0/Bijectors/39uFz.ji for Bijectors [76274a88-744f-5084-9051-94815aaf08c4]

julia> using Bijectors: Log

julia> b = Log{0}()
Log{0}()

julia> b(1.0) # ✓Correctly treated as single input
0.0

julia> b([1.0, 2.0])  # ✓Correctly treated as a batch of inputs
2-element Array{Float64,1}:
 0.0               
 0.6931471805599453

julia> b = Log{1}() # expects dim 1 as output
Log{1}()

julia> b([1.0, 2.0]) # ✓Correctly transformed as a single input
2-element Array{Float64,1}:
 0.0               
 0.6931471805599453

julia> logabsdetjac(b, 1.0) # ✓Correctly not implemented
ERROR: MethodError: no method matching logabsdetjac(::Log{1}, ::Float64)
Closest candidates are:
  logabsdetjac(::Log{0}, ::Real) at /home/tor/.julia/dev/Bijectors/src/interface.jl:314
  logabsdetjac(::Inversed{#s43,N} where N where #s43<:Bijector, ::Any) at /home/tor/.julia/dev/Bijectors/src/interface.jl:77
  logabsdetjac(::ADBijector, ::Real) at /home/tor/.julia/dev/Bijectors/src/interface.jl:135
  ...
Stacktrace:
 [1] top-level scope at none:0

julia> logabsdetjac(b, [1.0, 2.0])  # ✓Correctly treated as a single input
-0.6931471805599453

julia> logabsdetjac(b, [1.0 2.0 3.0; 3.0 4.0 5.0]) # ✓Correctly treated as a batch of inputs
1×3 Array{Float64,2}:
 -1.09861  -2.07944  -2.70805

We can then do this for the rest of the bijectors. Worth noting that many of the bijectors don't need a dimension parameter at all because they are only well-defined for one dimensionality, e.g. SimplexBijector is then simply implemented as

struct SimplexBijector{T} <: Bijector{1} end

since it only makes sense as a Bijector treating 1D vector as a single input and 2D matrix as batch of inputs.

All in all, I think it looks a bit nasty internally, but for the user it's going to be quite nice :)

EDIT: Another nice consequence is that we can catch dimensionality discrepancies in Composed at compile-time, i.e. it's not possible to compose bijectors expecting inputs of different dimensions!
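
A minimal sketch of how such a check could look, assuming a hypothetical `Composed` wrapper and `compose` function (neither is taken from the actual Bijectors.jl source). The mismatch is rejected by dispatch when the composition is constructed, before anything is evaluated on an input:

```julia
# Hypothetical, self-contained sketch; not the Bijectors.jl implementation.
abstract type Bijector{N} end

struct Log{N} <: Bijector{N} end
struct Exp{N} <: Bijector{N} end

(b::Log)(x) = log.(x)
(b::Exp)(x) = exp.(x)

# Both components are constrained to share the same `N`.
struct Composed{N, B1 <: Bijector{N}, B2 <: Bijector{N}} <: Bijector{N}
    inner::B1   # applied first
    outer::B2   # applied second
end

# Only defined when the two dimensionalities agree.
compose(b1::Bijector{N}, b2::Bijector{N}) where {N} =
    Composed{N, typeof(b1), typeof(b2)}(b1, b2)

(c::Composed)(x) = c.outer(c.inner(x))

c = compose(Exp{1}(), Log{1}())
c([1.0, 2.0])                    # round-trips to approximately [1.0, 2.0]

# compose(Exp{0}(), Log{1}())    # no matching method, so this throws a MethodError
```

Whether the error surfaces at compile-time or only when `compose` is called depends on how much the compiler can infer at the call site; either way it happens before any input is transformed.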

EDIT 2: The use of @generated is fairly unnecessary in the Log case:

logabsdetjac(b::Log{0}, x::Real) = -log(x)
logabsdetjac(b::Log{0}, x::AbstractVector) = -log.(x)
logabsdetjac(b::Log{1}, x::AbstractVector) = - sum(log.(x))
logabsdetjac(b::Log{1}, x::AbstractMatrix) = - vec(sum(log.(x); dims = 1))

This replaces all the business of _logabsdetjac and _logabsdetjac_batch 👍

Also, in this particular case you could just use broadcasting for batches of Real, but you'd still have a problem with matrix-valued inputs, etc., where the code would look similar but broadcasting wouldn't do the right thing.
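
To illustrate the matrix-valued case, here is a small sketch. It assumes a `Log{2}` that treats a matrix as a single input and a 3D array as a batch stacked along the last dimension (an assumption, chosen to match the columns-as-inputs convention used for `Log{1}` above). Broadcasting over the array would visit individual scalars, so the batch method has to work slice by slice instead:

```julia
# Hypothetical, self-contained sketch; the batch-along-last-dimension convention is an assumption.
abstract type Bijector{N} end
struct Log{N} <: Bijector{N} end

# Single matrix-valued input: one scalar for the whole matrix.
logabsdetjac(b::Log{2}, x::AbstractMatrix) = -sum(log, x)

# Batch of matrices stacked along the third dimension: one scalar per slice.
logabsdetjac(b::Log{2}, x::AbstractArray{<:Real, 3}) =
    [logabsdetjac(b, view(x, :, :, i)) for i in axes(x, 3)]

x = ones(2, 2, 3)
x[:, :, 2] .= 2.0
r = logabsdetjac(Log{2}(), x)   # 3-element vector, one entry per 2×2 slice
```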

torfjelde commented on July 2, 2024

So I think there are two main approaches to solving this, and I'd love to hear what people think.

Size

Explicitly provide size of expected input for all bijectors you construct (you also need to do it for the ones which are only well-defined on a specific type of input, e.g. SimplexBijector).

abstract type Bijector{Size} end

# Example of construction:
# Expects input `x` with `size(x) = (3, )` 
# OR `size(x) = (3, n)` where `n` is the number of inputs ("batch-computation")
Logit{(3, )}()

# Can also make it a bit nicer by adding constructors, e.g.:
Logit(size::Tuple) = Logit{size}()

Pros

  • Catch size discrepancies for compositions of bijectors during composition (compile-time) rather than during evaluation (run-time)
  • Can do checks at compile-time when using StaticArray (but whether you're catching size discrepancies at compile-time or run-time when actually evaluating a Bijector on an input doesn't really matter as far as I can tell)
  • The upcoming Stacked bijector in #36 provides a way of stacking 0- and 1-dimensional bijectors to act on a single vector, e.g. apply Logit to x[1:3] and SimplexBijector to x[4:end] or something. With static sizes we'd be able to provide sane defaults to the ranges, e.g. stacking Logit{(N1, )} and SimplexBijector{(N2, )} will apply Logit to x[1:N1] and SimplexBijector to x[N1 + 1:end]. Can also do this for more than just two bijectors. At the moment, the user has to provide the ranges unless they're all 0-dimensional in which case we'll just do elementwise application.
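
As a rough sketch of the "sane defaults" idea in the last point, consecutive index ranges can be derived from the per-bijector lengths. `default_ranges` is a hypothetical helper written for illustration, not part of Bijectors.jl or of PR #36:

```julia
# Hypothetical helper: turn per-bijector input lengths into consecutive index ranges.
function default_ranges(lengths::NTuple{K, Int}) where {K}
    lens = collect(lengths)
    stops = cumsum(lens)          # last index covered by each bijector
    starts = stops .- lens .+ 1   # first index covered by each bijector
    return [a:b for (a, b) in zip(starts, stops)]
end

default_ranges((3, 2))   # [1:3, 4:5]
```

With lengths `(N1, N2)` this gives `1:N1` for the first bijector and `N1 + 1:N1 + N2` for the second, and it extends to more than two bijectors unchanged.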

Cons

  • It puts a lot of effort on the user. You need to do this explicitly for all bijectors.
  • When working with non-static arrays you'll in many cases have to reconstruct the Bijector on each input to ensure dimensionality of the Bijector is correct. This can be inefficient (and annoying) in certain cases.
  • It can also get fairly ugly internally, which might also be a problem for the user if they want to implement their own Bijector

Dimensionality

Instead of the size of the expected input, we can parameterize Bijector by the dimensionality of the expected input, i.e. length(size(x)). This is basically like Univariate and Multivariate from Distributions.jl but in full generality.

abstract type Bijector{N} end

# Example of construction:
Logit{1}() # expects vector input

Could alternatively wrap N in some type Dim, e.g. Dim{N} instead, or something like that to give N more semantic meaning.

In a lot of cases a bijector is only well-defined for a particular dimensionality, e.g. SimplexBijector is only well-defined on a 1D vector. Such cases will then subtype Bijector{specific_value_of_N}, e.g.

struct SimplexBijector <: Bijector{1} end

Pros

  • Catching dimensionality (not size) discrepancies at composition rather than at evaluation
  • From a user-perspective, there's not much change. Only certain bijectors require the user to provide the dimensionality of the expected input since most are only well-defined on, say, a 1D vector.
  • Often easy to provide sane defaults, e.g. Log() = Log{0}().
  • Provide default constructor for Stacked in the case where the bijectors to stack are all 0-dimensional

Cons

  • For a small subset of bijectors, the user has to specify the dimensions.

torfjelde commented on July 2, 2024

Personally I prefer the "Dimensionality" approach 🤷‍♂️

xukai92 commented on July 2, 2024

Seems that for the Stacked purposes the "Size" approach is better?

torfjelde commented on July 2, 2024

Yeah, but I'm also afraid that Size will cause issues with type-stability due to needing run-time information for construction (unless using StaticArray, which is worse than standard Array in higher dimensions). So for Stacked it would indeed be better in the case where we construct the bijector at the top-level, but as soon as we construct it inside a function call or whatever we'd fall back to dynamic dispatch, right? This wouldn't be the case if we only depend on the dimensionality of the input array (which is known at compile time).
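
A small sketch of the type-stability concern, using two hypothetical stand-in types (`LogitSize` and `LogitDim` are illustrative names, not existing Bijectors.jl types). The size of a standard `Array` is a run-time value, so the return type of the first constructor can't be inferred as a concrete type, while the number of dimensions is part of the array's type:

```julia
# Hypothetical stand-ins for the two parameterizations discussed in this thread.
struct LogitSize{Size} end   # "Size" approach
struct LogitDim{N} end       # "Dimensionality" approach

# Construct a bijector matching the input from inside a function.
make_size(x::AbstractArray) = LogitSize{size(x)}()
make_dim(x::AbstractArray) = LogitDim{ndims(x)}()

# Inferred return types for a `Vector{Float64}` argument.
size_rt = Base.return_types(make_size, (Vector{Float64},))[1]   # not a concrete type
dim_rt = Base.return_types(make_dim, (Vector{Float64},))[1]     # LogitDim{1}
```

`isconcretetype(size_rt)` is `false`, so callers of `make_size` fall back to dynamic dispatch, whereas `make_dim` is fully inferred.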

I think I really prefer not using Size now. Alternatively we can provide a StaticBijector or something later on? E.g. Introduce AbstractBijector, let abstract Bijector{N} <: AbstractBijector and abstract StaticBijector{Size} <: AbstractBijector?

torfjelde commented on July 2, 2024

I think I'm actually quite fond of using the "Dimensionality" approach now, and then we can introduce StaticBijector later as follows:

# Replaces `Bijector` as root-type
abstract type AbstractBijector end

# Redefine `Bijector{N}` as subtype of `AbstractBijector`
# => all existing code remains the same
abstract type Bijector{N} <: AbstractBijector end

# Introduce new type parameterized by `Size`
abstract type StaticBijector{Size} <: AbstractBijector end

# implement fixed-size implementations of different bijectors
...

mohamed82008 commented on July 2, 2024

StaticBijector also sounds good to me.

xukai92 commented on July 2, 2024

Yes it sounds like a good plan to me.

torfjelde commented on July 2, 2024

I think this was mistakenly closed by the merging of PR #36? It's still unsolved, but an implementation, and thus a fix for this issue, is almost done and can be found in PR #44.

torfjelde commented on July 2, 2024

Just to make sure: is everyone happy with this approach?

If so, I'll make the final changes to PR #44 so we can get this merged. I have a lot of work on my personal GitHub which uses this branch as a base, and it's been working really well thus far.

torfjelde commented on July 2, 2024

Should we close this issue now that #44 has been merged? Or keep it up for future discussion? I'm still not 100% convinced the approach we've taken is absolutely optimal (though at the moment it seems like the best route)

willtebbutt commented on July 2, 2024

I'm pro-closing this. We could always open a new issue at a later date if someone has a bright idea, but I'm in agreement that this solution feels pretty optimal for now.

torfjelde commented on July 2, 2024

We should also bump the version number at some point, since this is possibly a breaking change.
But there's a lot of work ready to go based on #44 being merged, so we should probably get that into master before making an actual release.

torfjelde commented on July 2, 2024

Closed by #44
