arviz-devs / PSIS.jl
Pareto smoothed importance sampling
Home Page: https://julia.arviz.org/PSIS
License: MIT License
Sometimes a user may want to diagnose whether importance sampling will work without smoothing the weights. This is especially useful when computing expectation-specific diagnostics (see #21).
For testing and visualization purposes, it would be useful to return a `PSISResult` object that contains intermediate computed quantities. This would contain the entire fitted `GeneralizedPareto` distribution, overload `getproperty` to provide weights (instead of just log-weights), and overload `show` to provide a more useful visualization of diagnostics.
We could also then implement Makie and Plots type recipes for plotting the results object, which would look like ArviZ.jl's Pareto-shape plot: https://arviz-devs.github.io/ArviZ.jl/stable/mpl_examples/#Pareto-Shape-Plot. This would ultimately replace the `plot_khat` plotting function in ArviZ.jl. Instead, a user would call `plot(psis(log_ratios))` when running PSIS.
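A minimal sketch of the `getproperty` idea, assuming a hypothetical `ResultSketch` type that stores only log-weights and a shape value. The type and field names are made up for illustration and are not the actual `PSISResult` layout:

```julia
# Hypothetical stand-in for a result type; field names are assumptions.
struct ResultSketch
    log_weights::Vector{Float64}
    pareto_shape::Float64
end

# Expose a derived `weights` property without storing it.
function Base.getproperty(r::ResultSketch, name::Symbol)
    if name === :weights
        return exp.(getfield(r, :log_weights))
    end
    return getfield(r, name)
end

# Advertise the derived property alongside the stored fields.
Base.propertynames(r::ResultSketch) = (fieldnames(ResultSketch)..., :weights)

result = ResultSketch(log.([0.25, 0.75]), 0.3)
result.weights  # weights recovered from the stored log-weights
```

Computing the weights on access keeps the stored object small, at the cost of recomputing `exp` on each call.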
There are a number of methods in the literature for fitting the generalized Pareto distribution. We currently implement both the methods of Zhang & Stephens, 2009 (used in the PSIS paper) and Zhang, 2010, which makes some improvements for `k>1`. Switching between these two is supported by the `improved` keyword. This is not well named, as it gives the impression that the Zhang, 2010 method is universally better (for `k ∈ (0, 1)`, the range for which smoothing can help the most, it generally is not). These methods are relatively simple to implement and work well; some others would require heavier dependencies like Optim.jl and may be more costly.
It would be preferable to support easily adding new fitting methods, via something like a `gpd_fit_method` keyword, which might take a symbol or a singleton object as value. Internally, we should dispatch on this keyword.
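A rough sketch of what dispatching on such a keyword could look like. All type and function names here (`GPDFitMethod`, `MethodOfMoments`, `fit_gpd`, `fit_tail`) are hypothetical, and the only implemented estimator is a simple method-of-moments fit used as a stand-in; it is not either of the Zhang methods:

```julia
using Statistics  # standard library: mean, var

# Hypothetical singleton types naming the fitting methods.
abstract type GPDFitMethod end
struct MethodOfMoments <: GPDFitMethod end    # simple stand-in, implemented below
struct ZhangStephens2009 <: GPDFitMethod end  # placeholder; not implemented here

# Fallback for methods without an implementation.
fit_gpd(method::GPDFitMethod, x) =
    throw(ArgumentError("no fit implemented for $(typeof(method))"))

# Method-of-moments fit for exceedances x > 0; only meaningful when k < 1/2.
function fit_gpd(::MethodOfMoments, x)
    m, v = mean(x), var(x)
    ratio = m^2 / v
    k = (1 - ratio) / 2
    σ = m * (1 + ratio) / 2
    return (σ=σ, k=k)
end

# The keyword selects the method; dispatch picks the implementation.
fit_tail(x; gpd_fit_method::GPDFitMethod=MethodOfMoments()) = fit_gpd(gpd_fit_method, x)

# Deterministic, roughly exponential exceedances, so k should be near 0 and σ near 1.
x = [-log(1 - (i - 0.5) / 1000) for i in 1:1000]
fit = fit_tail(x)
```

With this layout, adding a new fitting method would only require a new singleton type and one new `fit_gpd` method.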
In the last 24 hours the reference tests began failing. This is happening because a new method for randomly generating normally distributed numbers was introduced in JuliaStats/Distributions.jl#1680, defined at https://github.com/JuliaStats/Distributions.jl/blob/221a9e801c7cf5280f85b2c893862688c60322d8/src/univariate/continuous/normal.jl#L110
We can just regenerate the references. It would be even better if we could compare to a reference implementation, e.g. loo's.
StyledStrings is part of the standard library as of Julia v1.11, but it can be installed separately for older Julia versions. It would allow us to define named "faces" (styles) for the different k-hat categories, which dependents could then use. I think in principle this would allow our colors to appear in HTML and Markdown representations, though currently not many packages have implemented support for styled strings.
Unless we can figure out a way to make this support an extension, on v1.10 and older, having StyledStrings as a direct dependency will increase package load times by ~1s.
When all log-ratios are low enough that their exponential is exactly 0, then for the returned `GeneralizedPareto` object, the corresponding `σ` value will be zero, even if the actual `σ` fit to the tail is non-zero, e.g.:
julia> using PSIS
julia> psis(randn(1000) .+ -1500)
ERROR: ArgumentError: GeneralizedPareto: the condition σ > zero(σ) is not satisfied.
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Distributions/K939H/src/utils.jl:6 [inlined]
[2] #GeneralizedPareto#147
@ ~/.julia/packages/Distributions/K939H/src/univariate/continuous/generalizedpareto.jl:41 [inlined]
We could construct `GeneralizedPareto` with `check_args=false` to silence the error, but this still presents a useless object. Instead of returning the `GeneralizedPareto` fit to the provided log-ratios, we should return one of the following:
[...] `μ=0`)
I suspect the 2nd option is more useful in the end to the user, though we should then also include in `PSISResult` any necessary intermediates for the user to reconstruct the tail weights themselves.
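The underflow behind this report can be reproduced without PSIS at all. The snippet below is only an illustration of the floating-point effect and of one common remedy (shifting by the maximum before exponentiating); it is not a claim about how PSIS.jl computes its weights:

```julia
# Log-ratios this negative underflow to exactly 0.0 when exponentiated in Float64.
log_ratios = [-1500.0, -1501.0, -1502.0]
naive = exp.(log_ratios)

# Shifting by the maximum changes the ratios only by a common factor, which
# cancels on normalization, and keeps the largest value at exactly 1.0.
shifted = exp.(log_ratios .- maximum(log_ratios))
normalized = shifted ./ sum(shifted)
```

Any scale parameter expressed in units of the unshifted ratios inherits the same underflow, which is consistent with the zero `σ` described above.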
Currently `r_eff` is a required 2nd argument to `psis` and `psis!`. We should consider making it an optional argument. I can think of two ways to do this:
1. Default to `r_eff=ones(nparams)`. The risk is that users call `psis` on MCMC-generated log-ratios without bothering to read the docs.
2. Compute `r_eff` using `ess_rhat` from MCMCDiagnosticTools, potentially using Requires.
I'm leaning towards (1).
If any `NaN` or `Inf` values are present in the log-ratios vector, then an error is raised by the GPD distribution. The same happens if all log-ratios are `-Inf`. For all of these cases, it would be more useful to detect the problem early and raise a warning that PSIS fails instead of raising an error.
Another case is when too many elements of the upper tail vector are `-Inf`. Then, likewise, an error is raised, and it's not immediately clear to me why that should be.
julia> using PSIS
julia> psis(fill(1, 100)) # repeated values are fine
PSISResult with 1 parameters, 100 draws, and 1 chains
Pareto shape (k) diagnostic values:
                    Count       Min. ESS
 (-Inf, 0.5]  good  1 (100.0%)  100
julia> psis(fill(NaN, 100))
ERROR: DomainError with NaN:
GeneralizedPareto: the condition σ > zero(σ) is not satisfied.
...
julia> psis(fill(Inf, 100))
ERROR: DomainError with NaN:
GeneralizedPareto: the condition σ > zero(σ) is not satisfied.
...
julia> psis(fill(-Inf, 100))
ERROR: DomainError with NaN:
GeneralizedPareto: the condition σ > zero(σ) is not satisfied.
...
julia> psis_result = psis(vcat(ones(50), fill(-Inf, 435)))
ERROR: DomainError with NaN:
GeneralizedPareto: the condition σ > zero(σ) is not satisfied.
...
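A sketch of the kind of early check described above. The function name `check_log_ratios` and its Boolean return convention are assumptions made for illustration:

```julia
# Hypothetical pre-check; the name and return convention are assumptions.
# Returns true if the log-ratios pass the basic checks, false (with a warning) otherwise.
function check_log_ratios(log_ratios)
    if any(isnan, log_ratios)
        @warn "PSIS cannot be applied: log-ratios contain NaN."
        return false
    elseif any(==(Inf), log_ratios)
        @warn "PSIS cannot be applied: log-ratios contain Inf."
        return false
    elseif all(==(-Inf), log_ratios)
        @warn "PSIS cannot be applied: all log-ratios are -Inf."
        return false
    end
    return true
end

check_log_ratios(fill(NaN, 100))  # warns and returns false
```

Note that this check lets the last case through (`vcat(ones(50), fill(-Inf, 435))`), since only part of the vector is `-Inf`; that case would need a separate check on the upper tail.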
As pointed out in stan-dev/loo#185, we only need the last `M` tail indices, so instead of sorting the whole array of weights, we can just sort the upper tail weights, which at least for `Vector` is faster. We can do this by replacing `sortperm` with `partialsortperm`.
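The equivalence of the two approaches can be checked on a small example. The data and the tail length `M = 50` are arbitrary illustrative choices:

```julia
# Deterministic data with distinct values, so the sort order is unambiguous.
log_weights = [sin(i) for i in 1:1000]
n = length(log_weights)
M = 50  # number of upper-tail draws (illustrative value)

# Full sort, then keep the last M indices.
tail_full = sortperm(log_weights)[(n - M + 1):n]

# Partial sort: only positions n-M+1 through n are put in order.
tail_partial = partialsortperm(log_weights, (n - M + 1):n)
```

Both return the indices of the `M` largest values in ascending order of value; with tied values the two could order the ties differently.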
Currently we depend on Distributions so that we can return the fitted generalized Pareto distribution as a Distributions object for convenience, but ~75% of our load time ends up being loading Distributions, and we don't even use its functionality. Perhaps we should instead offer a utility function `convert(::Type{Distributions.Distribution}, ::PSIS.GeneralizedPareto)` that is conditionally loaded if Distributions is in scope, so a user can easily get a Distribution if they want one, while we make the package even more lightweight.
Similarly, ~18% of our load time is due to PrettyTables, but we only use a tiny amount of its functionality. Perhaps we should manually create our own function for pretty-printing `PSISResult`, and then we can drop PrettyTables as a dependency. It seems this would get us down to 5% of our current load time.
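To give a sense of how little is needed for an internal type, here is a sketch of a dependency-free generalized Pareto struct with a quantile function. The names `GeneralizedParetoSketch` and `gpd_quantile` are hypothetical, and the conditional conversion to a Distributions object is not shown:

```julia
# Hypothetical lightweight replacement for a Distributions.jl object.
struct GeneralizedParetoSketch{T<:Real}
    μ::T  # location
    σ::T  # scale
    k::T  # shape
end

# Quantile function of the generalized Pareto distribution for 0 ≤ p < 1.
function gpd_quantile(d::GeneralizedParetoSketch, p::Real)
    z = -log1p(-p)  # standard exponential quantile
    # expm1(k z) / k tends to z as k → 0, which is the exponential case.
    return d.μ + d.σ * (iszero(d.k) ? z : expm1(d.k * z) / d.k)
end

gpd_quantile(GeneralizedParetoSketch(0.0, 2.0, 0.5), 0.75)  # 2 * ((0.25)^(-0.5) - 1) / 0.5
```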
The PSIS paper notes it is sometimes useful to compute `h`-specific (function/expectation-specific) diagnostics, which more specifically diagnose the quality of importance sampling for the expectation in question. From discussions with @avehtari, this can be done with the following procedure. Given importance ratios `r(θ)` and function evaluations `h(θ)`, we:
1. Compute `w(θ)` as `r(θ)` with a Pareto-smoothed upper tail.
2. Compute `v(θ) = h(θ) r(θ)`. Fit the GPD to upper tails of both `v(θ)` and `-v(θ)` to get `kₕᵤ` and `kₕₗ`. Compute `kₕ=max(kₕᵤ, kₕₗ)`.
3. Warn if `k > 0.7` or `kₕ > 0.7`.
4. Return `w(θ)`, `k`, and `kₕ`. The user would then estimate `𝔼[h]` as `sum(h .* w) ./ sum(w)`.
Likewise, we can return `h`-specific ESS estimates by simply replacing `w` in the ESS computation with `v`.
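The two-tail shape computation and the final estimate can be sketched as follows. To stay self-contained, a Hill estimator stands in for the GPD fit, so this is explicitly not the fit PSIS uses; `hill_shape`, `h_specific_shape`, and `snis_estimate` are hypothetical names, and the tail length `M` is an arbitrary input:

```julia
# Stand-in tail-shape estimator (Hill), NOT the GPD fit used by PSIS.
# Uses the M largest values of x; requires the (M+1)-th largest value to be > 0.
function hill_shape(x, M)
    xs = sort(x; rev=true)
    threshold = xs[M + 1]
    return sum(log.(xs[1:M] ./ threshold)) / M
end

# Shape diagnostics for both tails of v(θ) = h(θ) r(θ), combined by taking the maximum.
function h_specific_shape(r, h, M)
    v = h .* r
    k_upper = hill_shape(v, M)
    k_lower = hill_shape(-v, M)
    return max(k_upper, k_lower)
end

# Self-normalized importance sampling estimate of 𝔼[h].
snis_estimate(h, w) = sum(h .* w) / sum(w)

r = [1.0, 2.0, 4.0, 8.0]
h = [1.0, -1.0, 1.0, -1.0]
kh = h_specific_shape(r, h, 1)
estimate = snis_estimate([1.0, 2.0, 3.0], [1.0, 1.0, 2.0])
```

Fitting both `v` and `-v` matters because `h` can take either sign, so either tail of `v` may dominate the variance of the estimate.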
A few things to consider:
- `PSISResult` currently can't accommodate 3 smoothing results. Perhaps it should be extended, or a different object should be returned if `h` is provided.
- [...] `h`'s simultaneously could be more useful. Perhaps the API should support both.
- [...] `v`. Between this feature and that, some changes may need to be made to the internals.
Currently, if we use PSIS for a single parameter, we get the following output:
PSISResult with 1 parameters, 1000 draws, and 1 chains
Pareto shape (k) diagnostic values:
                    Count       Min. ESS
 (-Inf, 0.5]  good  1 (100.0%)  363
The downside here is that the actual shape value is not shown, and the `Count` column is not informative. Probably something like this would be more useful:
PSISResult with 1 parameters, 1000 draws, and 1 chains
Pareto shape (k) diagnostic value:
Shape k ESS
good (-Inf, 0.5] 0.3 363
OR
PSISResult with 1 parameters, 1000 draws, and 1 chains
Shape diagnostic k: 0.3 ∈ (-Inf, 0.5] (good)
ESS: 363
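The second proposed layout is easy to produce with plain string interpolation. In the sketch below, the function names are hypothetical, the boundaries 0.5 and 0.7 are the thresholds that appear in these issues, and the labels other than "good" are assumptions:

```julia
# Category boundaries 0.5 and 0.7 follow the thresholds mentioned in these issues;
# the labels other than "good" are assumptions.
function khat_category(k::Real)
    k ≤ 0.5 && return ("(-Inf, 0.5]", "good")
    k ≤ 0.7 && return ("(0.5, 0.7]", "okay")
    return ("(0.7, Inf)", "bad")
end

# Build the two-line summary for a single parameter.
function describe_khat(k::Real, ess::Real)
    interval, label = khat_category(k)
    return "Shape diagnostic k: $k ∈ $interval ($label)\nESS: $ess"
end

print(describe_khat(0.3, 363))
```

A real `show` method would call something like this only when the result holds a single parameter and fall back to the table otherwise.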
https://arxiv.org/abs/1906.08850 introduced an adaptive importance sampling method called Importance Weighted Moment Matching. Given a target distribution `p`, it takes Monte Carlo samples `θ` from a distribution `q`, the log density function of `q`, and a function `h` whose expectation one wants to take. It then alternates between transforming the Monte Carlo samples to `θ*` (effectively, modifying the proposal distribution) and computing the Pareto shape diagnostic `k` until `k ≤ 0.7`. The transformations used are affine and are chosen to match the first two moments of the proposal distribution to the first 2 importance-weighted moments.
The method returns either the estimated expectation `𝔼ₚ[h(θ)]` or, perhaps, `h(θ*)` and `w(θ*)`, as well as the shape diagnostic. It would be handy also to return the sequence of affine transformations (or their composition) used. Like #21, this motivates some rethinking of the API.
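A one-dimensional sketch of a single affine moment-matching step, assuming the weights are already available. The helper names are hypothetical, and the full method (recomputing the weights on the transformed draws, re-estimating `k`, and iterating) is not shown:

```julia
# Self-normalized weighted mean and variance.
weighted_mean(θ, w) = sum(w .* θ) / sum(w)
weighted_var(θ, w) = sum(w .* (θ .- weighted_mean(θ, w)) .^ 2) / sum(w)

# One affine moment-matching step in 1D: shift and scale the draws so that their
# unweighted mean and variance equal the importance-weighted ones.
function match_moments(θ, w)
    unit = ones(length(θ))
    μq, σq = weighted_mean(θ, unit), sqrt(weighted_var(θ, unit))
    μw, σw = weighted_mean(θ, w), sqrt(weighted_var(θ, w))
    return μw .+ (σw / σq) .* (θ .- μq)
end

θ = [0.0, 1.0, 2.0, 3.0]
w = [1.0, 1.0, 1.0, 5.0]
θ_star = match_moments(θ, w)
```

Because the map is affine, a sequence of such steps composes into a single shift and scale, which is what makes returning the composition practical.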
This issue is used to trigger TagBot; feel free to unsubscribe.
If you haven't already, you should update your `TagBot.yml` to include issue comment triggers.
Please see this post on Discourse for instructions and more details.
If you'd like for me to do this for you, comment `TagBot fix` on this issue.
I'll open a PR within a few hours, please be patient!
It appears that our CI is now failing on the Makie integration code. See https://github.com/arviz-devs/PSIS.jl/runs/8046708122?check_suite_focus=true for details.
We should consider making Plots and Makie integration subpackages. This would allow us to specify the compat bounds, and we could also remove our backend switching mechanism.