Comments (8)

devmotion commented on June 9, 2024

While Broadcast.broadcasted is not documented, it is used in many core packages (e.g. StatsBase, Distributions, ...), so breaking it would already break other parts of the pipeline. The main advantage of broadcasted (provided you call Broadcast.instantiate!) is that sum will use pairwise summation while still being fast. If you use zip or other iterators, no pairwise summation is performed.
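
A scaled-down, Base-only sketch of the difference described above. Assumptions: Julia 1.7 or later, where sum over an instantiated one-dimensional Broadcasted takes the pairwise mapreduce path; the array sizes and values are illustrative and not taken from the thread:

```julia
# One large value followed by many tiny ones. In Float32, 1f0 + 1f-8 rounds back
# to 1f0, so a strictly left-to-right sum never moves away from 1.
A = vcat(1f0, fill(1f-8, 10^6))
w = ones(Float32, length(A))

# Generator over zip: reduced sequentially, one element at a time.
s_zip = sum(a * b for (a, b) in zip(A, w))

# Lazy broadcast object. instantiate computes and stores its axes (the step the
# comment above says is needed); sum then reduces it without materializing A .* w.
bc = Broadcast.instantiate(Broadcast.broadcasted(*, A, w))
s_bc = sum(bc)

# The exact sum is about 1.01; the precise digits of s_bc may vary slightly.
println("zip: ", s_zip, "  broadcasted: ", s_bc)
```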

from gplikelihoods.jl.

simsurace commented on June 9, 2024

Interesting. Have you benchmarked this recently?

using BenchmarkTools, Distributions, GPLikelihoods
using GPLikelihoods: expected_loglikelihood

# Sum over a generator of zipped pairs (sequential reduction)
function my_expected_loglikelihood(
    gh::GPLikelihoods.GaussHermiteExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    return sum(expected_loglikelihood(gh, lik, q_fᵢ, yᵢ) for (q_fᵢ, yᵢ) in zip(q_f, y))
end

# mapfoldl over zip(y, q_f); each element is qfy = (yᵢ, q_fᵢ)
function my_expected_loglikelihood2(
    gh::GPLikelihoods.GaussHermiteExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    return mapfoldl(qfy -> expected_loglikelihood(gh, lik, qfy[2], qfy[1]), +, zip(y, q_f))
end

# mapreduce over zip(y, q_f); for a non-array iterator this falls back to mapfoldl
function my_expected_loglikelihood3(
    gh::GPLikelihoods.GaussHermiteExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    return mapreduce(qfy -> expected_loglikelihood(gh, lik, qfy[2], qfy[1]), +, zip(y, q_f))
end

julia> @btime expected_loglikelihood(GPLikelihoods.default_expectation_method($gp.lik), $gp.lik, $q_f, $y)
  1.426 ms (34 allocations: 13.34 KiB)
julia> @btime my_expected_loglikelihood(GPLikelihoods.default_expectation_method($gp.lik), $gp.lik, $q_f, $y)
  1.486 ms (3039 allocations: 91.67 KiB)
julia> @btime my_expected_loglikelihood2(GPLikelihoods.default_expectation_method($gp.lik), $gp.lik, $q_f, $y)
  1.422 ms (34 allocations: 13.34 KiB)
julia> @btime my_expected_loglikelihood3(GPLikelihoods.default_expectation_method($gp.lik), $gp.lik, $q_f, $y)
  1.421 ms (34 allocations: 13.34 KiB)

So my_expected_loglikelihood makes significantly more allocations (3039 vs. 34), but the performance of the current implementation can be reproduced with mapfoldl or mapreduce.

EDIT:

julia> versioninfo()
Julia Version 1.7.2
Commit bf53498635 (2022-02-06 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-9700K CPU @ 3.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)

devmotion commented on June 9, 2024

The main point is pairwise summation, while still being fast. zip does not use pairwise summation.
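
A Base-only way to check that claim, as a sketch with hypothetical data chosen so that summation order is visible in Float32: each zip-based form used in this thread is compared with an explicit left-to-right foldl.

```julia
A = vcat(1f0, fill(1f-8, 10^5))
w = ones(Float32, length(A))
zipped = zip(A, w)  # zip over arrays can be iterated repeatedly

# Explicitly sequential reference: combines elements strictly left to right.
s_seq = foldl(+, (a * b for (a, b) in zipped))

# The zip-based variants from the thread. An Iterators.Zip is not an
# AbstractArray, so none of them can take the pairwise array path.
s_gen = sum(a * b for (a, b) in zipped)
s_mapfoldl = mapfoldl(((a, b),) -> a * b, +, zipped)
s_mapreduce = mapreduce(((a, b),) -> a * b, +, zipped)
```

All four agree bit for bit here, which is what a sequential reduction predicts.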

devmotion commented on June 9, 2024

An example that illustrates my point:

julia> using BenchmarkTools, LinearAlgebra

julia> A = vcat(1f0, fill(1f-8, 10^8));

julia> w = ones(Float32, length(A));

julia> f_dot(A, w) = dot(A, w)
f_dot (generic function with 1 method)

julia> f_sum_zip(A, w) = sum(Ai * wi for (Ai, wi) in zip(A, w))
f_sum_zip (generic function with 1 method)

julia> f_sum_broadcasted(A, w) = sum(Broadcast.instantiate(Broadcast.broadcasted(*, A, w)))
f_sum_broadcasted (generic function with 1 method)

julia> f_sum_mapfoldl(A, w) = mapfoldl(+, zip(A, w)) do (Ai, wi)
           Ai * wi
       end
f_sum_mapfoldl (generic function with 1 method)

julia> f_sum_mapreduce(A, w) = mapreduce(+, zip(A, w)) do (Ai, wi)
           Ai * wi
       end
f_sum_mapreduce (generic function with 1 method)

julia> f_dot(A, w)
1.9625133f0

julia> f_sum_zip(A, w)
1.0f0

julia> f_sum_broadcasted(A, w)
1.9999989f0

julia> f_sum_mapfoldl(A, w)
1.0f0

julia> f_sum_mapreduce(A, w)
1.0f0

julia> @btime f_dot($A, $w);
  42.887 ms (0 allocations: 0 bytes)

julia> @btime f_sum_zip($A, $w);
  140.985 ms (0 allocations: 0 bytes)

julia> @btime f_sum_broadcasted($A, $w);
  46.262 ms (0 allocations: 0 bytes)

julia> @btime f_sum_mapfoldl($A, $w);
  141.182 ms (0 allocations: 0 bytes)

julia> @btime f_sum_mapreduce($A, $w);
  141.659 ms (0 allocations: 0 bytes)

simsurace commented on June 9, 2024

Very nice, thanks. Strangely, sum(A) seems to give the result of the broadcasted sum above. Apparently it uses pairwise summation as well. I wonder why it doesn't when iterating over a zip. Would it make sense to have a pairwise iterator in Base? This seems like such a common pattern.
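
A small Base-only check of that observation (hypothetical values; the tolerance reflects that pairwise summation is not exact, just far more accurate on this input):

```julia
A = vcat(1f0, fill(1f-8, 10^5))

s_sum = sum(A)         # AbstractArray: goes through Base's pairwise reduction
s_foldl = foldl(+, A)  # strictly left to right, so every 1f-8 is rounded away

# The exact sum is about 1.001.
println(s_sum, " vs ", s_foldl)
```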

devmotion commented on June 9, 2024

Yes, sum on arrays also uses pairwise summation. mapreduce is specialized and uses pairwise reduction for AbstractArray and Broadcasted (https://github.com/JuliaLang/julia/blob/bf534986350a991e4a1b29126de0342ffd76205e/base/reduce.jl#L235-L257), and sum(f, A) is just mapreduce(f, add_sum, A) under the hood.
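
The idea behind that pairwise path can be sketched in a few lines. This is a simplified, hypothetical pairwise_sum written for illustration, not Base's mapreduce_impl (which also handles arbitrary operators, @simd loops, and empty inputs); the default block size of 1024 is chosen to roughly mirror Base's pairwise block size:

```julia
# Recursive pairwise summation of f over a non-empty vector (illustrative sketch).
function pairwise_sum(f, A::AbstractVector, lo::Int = firstindex(A), hi::Int = lastindex(A);
                      blocksize::Int = 1024)
    if hi - lo < blocksize
        # Base case: a short block is summed with a plain left-to-right loop.
        s = f(A[lo])
        for i in (lo + 1):hi
            s += f(A[i])
        end
        return s
    end
    # Otherwise split the index range in half and add the two partial sums, so the
    # worst-case rounding error grows roughly like log(n) rather than n.
    mid = (lo + hi) >>> 1
    return pairwise_sum(f, A, lo, mid; blocksize = blocksize) +
           pairwise_sum(f, A, mid + 1, hi; blocksize = blocksize)
end

A = vcat(1f0, fill(1f-8, 10^5))
println(pairwise_sum(identity, A), " vs ", foldl(+, A))
```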

simsurace commented on June 9, 2024

I see. It seems that something that zips multiple vectors but still allows linear indexing would accomplish the same thing as Broadcast.broadcasted. It's trivial to modify the linked mapreduce_impl to support two input arrays, but mapreduce(*, +, A, w) instead calls a different method, which uses Base.Generator and seems to allocate. If pairwise summation is generally preferable and is the default for sum(::AbstractArray), it seems it should be the default for other reductions with + as well.
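
One way to sketch the "zip with linear indexing" idea without touching Base: a small hypothetical ZipVector type that wraps two one-based vectors as an AbstractVector of tuples, so sum(f, ...) dispatches to the AbstractArray method and takes the pairwise path. An alternative that needs no new type is to reduce over eachindex(A, w), an index range that is itself an AbstractArray. Both are assumptions checked here only on toy data:

```julia
# Hypothetical lazily zipped pair of vectors with linear indexing.
struct ZipVector{T,VA<:AbstractVector,VB<:AbstractVector} <: AbstractVector{T}
    a::VA
    b::VB
    function ZipVector(a::AbstractVector, b::AbstractVector)
        Base.require_one_based_indexing(a, b)
        length(a) == length(b) || throw(DimensionMismatch("vectors must have equal length"))
        return new{Tuple{eltype(a),eltype(b)},typeof(a),typeof(b)}(a, b)
    end
end

# Minimal AbstractArray interface: size, linear index style, scalar getindex.
Base.size(z::ZipVector) = size(z.a)
Base.IndexStyle(::Type{<:ZipVector}) = IndexLinear()
Base.@propagate_inbounds Base.getindex(z::ZipVector, i::Int) = (z.a[i], z.b[i])

A = vcat(1f0, fill(1f-8, 10^5))
w = ones(Float32, length(A))

# ZipVector is an AbstractVector, so this sum goes through the array reduction.
s_zipvec = sum(t -> t[1] * t[2], ZipVector(A, w))

# Same idea without a new type: reduce over the shared index range.
s_idx = sum(i -> A[i] * w[i], eachindex(A, w))
```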

simsurace commented on June 9, 2024

I can't figure out why expected_loglikelihood is type-unstable, though. It does not happen in other examples, such as f_sum_broadcasted above.
