psis.jl's People

Contributors

dependabot[bot], github-actions[bot], sethaxen

Forkers

mfleader

psis.jl's Issues

Allow diagnosing without smoothing

Sometimes a user may want to diagnose whether importance sampling will work without smoothing the weights. This is especially useful when computing expectation-specific diagnostics (see #21).

Return a results object

For testing and visualization purposes, it would be useful to return a PSISResult object that contains intermediate computed quantities. This would contain the entire fitted GeneralizedPareto distribution, overload getproperty to provide weights (instead of just log-weights), and overload show to provide a more useful visualization of diagnostics.
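
A rough sketch of what such a results type could look like; the field names, layout, and show output below are illustrative assumptions, not a final design.

struct PSISResult{T,D}
    log_weights::Vector{T}  # smoothed log-weights
    tail_dist::D            # fitted GeneralizedPareto for the upper tail
    r_eff::Vector{T}        # relative efficiencies of the raw ratios
end

# Expose normalized weights lazily instead of storing them alongside the log-weights.
function Base.getproperty(r::PSISResult, name::Symbol)
    if name === :weights
        lw = getfield(r, :log_weights)
        w = exp.(lw .- maximum(lw))
        return w ./ sum(w)
    end
    return getfield(r, name)
end

# Show a diagnostic summary rather than dumping the raw fields.
function Base.show(io::IO, ::MIME"text/plain", r::PSISResult)
    k = getfield(r, :tail_dist).ξ  # assumes the shape parameter is stored as ξ
    print(io, "PSISResult with ", length(getfield(r, :log_weights)),
        " draws; Pareto shape k = ", round(k; digits=2))
end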

We could also then implement Makie and Plots type recipes for plotting the results object, which would look like ArviZ.jl's Pareto-shape plot: https://arviz-devs.github.io/ArviZ.jl/stable/mpl_examples/#Pareto-Shape-Plot. This would ultimately replace the plot_khat plotting function in ArviZ.jl. Instead, a user would call plot(psis(log_ratios)) when running PSIS.

Support testing other GPD fitting methods

There are a number of methods in the literature for fitting the generalized Pareto distribution. We currently implement the methods of both Zhang & Stephens (2009; used in the PSIS paper) and Zhang (2010), which makes some improvements for k > 1. Switching between the two is controlled by the improved keyword. That name is poorly chosen, as it gives the impression that the Zhang (2010) method is universally better (for k ∈ (0, 1), the range where smoothing helps the most, it generally is not). These methods are relatively simple to implement and work well; some others would require heavier dependencies like Optim.jl and may be more costly.

It would be preferable to make adding new fitting methods easy, via something like a gpd_fit_method keyword, which might take a symbol or a singleton object as its value. Internally, we should dispatch on this keyword.
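
A minimal sketch of the dispatch pattern, assuming hypothetical singleton types and a hypothetical fit_gpd internal (none of these names are final):

abstract type GPDFitMethod end
struct ZhangStephens2009 <: GPDFitMethod end
struct Zhang2010 <: GPDFitMethod end

# Each fitting method gets its own `fit_gpd` method; bodies are elided here.
fit_gpd(tail::AbstractVector, ::ZhangStephens2009) = error("elided in this sketch")
fit_gpd(tail::AbstractVector, ::Zhang2010) = error("elided in this sketch")

# The keyword just forwards the singleton to dispatch, so adding a new method
# only requires defining a new type and a corresponding `fit_gpd` method.
function psis_sketch(log_ratios; gpd_fit_method::GPDFitMethod=ZhangStephens2009())
    tail = log_ratios  # upper-tail selection omitted in this sketch
    return fit_gpd(tail, gpd_fit_method)
end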

Reference tests failing on CI

In the last 24 hours the reference tests began failing. This is happening because a new method for randomly generating normally distributed numbers was introduced in JuliaStats/Distributions.jl#1680, defined at https://github.com/JuliaStats/Distributions.jl/blob/221a9e801c7cf5280f85b2c893862688c60322d8/src/univariate/continuous/normal.jl#L110

We can just regenerate the references. It would be even better if we could compare to a reference implementation, e.g. loo's.

Use StyledStrings

StyledStrings is part of the standard library as of Julia v1.11, and it can be installed separately on older Julia versions. It would allow us to define named "faces" (styles) for the different k-hat categories, which dependents could then use. In principle this would also allow our colors to appear in HTML and Markdown representations, though currently not many packages have implemented support for styled strings.

Unless we can figure out a way to provide this support via a package extension, having StyledStrings as a direct dependency will increase package load times by ~1 s on v1.10 and older.
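
A hedged sketch of how the faces could be registered and used; the face names and the category thresholds below are illustrative, and the exact StyledStrings API should be checked before relying on it:

using StyledStrings

# Register named faces for the k-hat categories; dependents could override these.
StyledStrings.addface!(:psis_good => StyledStrings.Face(foreground=:green))
StyledStrings.addface!(:psis_bad => StyledStrings.Face(foreground=:yellow))
StyledStrings.addface!(:psis_very_bad => StyledStrings.Face(foreground=:red, weight=:bold))

k = 0.62
if k ≤ 0.5
    println(styled"{psis_good:good} (k = $k)")
elseif k ≤ 0.7
    println(styled"{psis_bad:bad} (k = $k)")
else
    println(styled"{psis_very_bad:very bad} (k = $k)")
end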

Constructing fit tailed distribution fails

When all log-ratios are low enough that their exponentials are exactly 0, the σ of the returned GeneralizedPareto object will be zero, even if the actual σ fit to the tail is non-zero. For example:

julia> using PSIS

julia> psis(randn(1000) .+ -1500)
ERROR: ArgumentError: GeneralizedPareto: the condition σ > zero(σ) is not satisfied.
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Distributions/K939H/src/utils.jl:6 [inlined]
  [2] #GeneralizedPareto#147
    @ ~/.julia/packages/Distributions/K939H/src/univariate/continuous/generalizedpareto.jl:41 [inlined]

We could construct GeneralizedPareto with check_args=false to silence the error, but that would still return a useless object. Instead of returning the GeneralizedPareto fit to the provided log-ratios, we should return one of the following:

  • the corresponding fit to the normalized weights
  • the corresponding fit to the tail weights (i.e. after shifting so that μ=0)

I suspect the second option is ultimately more useful to the user, though we should then also include in PSISResult whatever intermediates the user needs to reconstruct the tail weights themselves.
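
A minimal sketch of the second option, shifting to a representable scale before exponentiating; fit_gpd is a hypothetical placeholder, and the exact shift is an assumption rather than the package's internals:

function fit_tail_dist(log_ratios::AbstractVector, tail_inds)
    log_tail = log_ratios[tail_inds]
    shift = maximum(log_ratios)             # largest ratio maps to exp(0) = 1, so no underflow
    tail_weights = exp.(log_tail .- shift)  # representable even when log_ratios ≈ -1500
    cutoff = minimum(tail_weights)
    return fit_gpd(tail_weights .- cutoff)  # hypothetical fitting routine; fit the exceedances, i.e. μ = 0
end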

Make `r_eff` an optional argument

Currently r_eff is a required 2nd argument to psis and psis!. We should consider making it an optional argument. I can think of two ways to do this:

  1. Default to r_eff=ones(nparams)
  • pros: easy, convenient, no additional dependencies, consistent with ArviZ.jl
  • cons: potentially promotes misuse, e.g. a user could run psis on MCMC-generated log-ratios without bothering to read the docs.
  2. Compute r_eff using ess_rhat from MCMCDiagnosticTools, potentially via Requires
  • pros: the correct thing is done by default
  • cons: either adds a dependency or makes the method not always available, and computing r_eff by default can be wasteful

I'm leaning towards (1).
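
A sketch of option (1); the signature name and the assumption that parameters lie along the first dimension are purely illustrative:

function psis_sketch(log_ratios::AbstractMatrix; r_eff=ones(size(log_ratios, 1)))
    length(r_eff) == size(log_ratios, 1) ||
        throw(DimensionMismatch("r_eff needs one entry per parameter"))
    # ... smooth each parameter's log-ratios, using r_eff to set the tail length ...
    return log_ratios  # placeholder return; real code would return smoothed weights/results
end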

Errors raised with non-finite numbers

If any NaN or Inf values are present in the log-ratios vector, an error is raised when constructing the GeneralizedPareto distribution. The same happens if all log-ratios are -Inf. In all of these cases it would be more useful to detect the problem early and warn that PSIS fails, instead of raising an error.

Another case is when too many elements of the upper tail are -Inf; likewise an error is raised, and it's not immediately clear to me why that should be.

julia> using PSIS

julia> psis(fill(1, 100))  # repeated values are fine
PSISResult with 1 parameters, 100 draws, and 1 chains
Pareto shape (k) diagnostic values:
                    Count       Min. ESS 
 (-Inf, 0.5]  good  1 (100.0%)  100

julia> psis(fill(NaN, 100))
ERROR: DomainError with NaN:
GeneralizedPareto: the condition σ > zero(σ) is not satisfied.
...

julia> psis(fill(Inf, 100))
ERROR: DomainError with NaN:
GeneralizedPareto: the condition σ > zero(σ) is not satisfied.
...

julia> psis(fill(-Inf, 100))
ERROR: DomainError with NaN:
GeneralizedPareto: the condition σ > zero(σ) is not satisfied.
...

julia> psis_result = psis(vcat(ones(50), fill(-Inf, 435)))
ERROR: DomainError with NaN:
GeneralizedPareto: the condition σ > zero(σ) is not satisfied.
...
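
A sketch of the kind of early check this could use; the warning text and the exact policy are illustrative:

function check_finite(log_ratios)
    if any(isnan, log_ratios) || any(==(Inf), log_ratios)
        @warn "NaN or +Inf found in log-ratios; Pareto smoothing is skipped for this parameter."
        return false
    elseif all(==(-Inf), log_ratios)
        @warn "All log-ratios are -Inf; Pareto smoothing is skipped for this parameter."
        return false
    end
    return true
end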

Only sort the upper tail weights

As pointed out in stan-dev/loo#185, we only need the last M tail indices, so instead of sorting the whole array of weights, we can sort only the upper tail weights, which, at least for Vector, is faster. We can do this by replacing sortperm with partialsortperm.
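
A sketch of the partial sort; the tail-length rule here is simplified and ignores r_eff:

function tail_indices(log_weights::AbstractVector)
    S = length(log_weights)
    M = ceil(Int, min(S / 5, 3 * sqrt(S)))               # simplified tail length
    return partialsortperm(log_weights, (S - M + 1):S)   # indices of the M largest, ascending
end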

Reducing dependencies

Currently we depend on Distributions so that we can return the fitted generalized Pareto distribution as a Distributions object for convenience, but ~75% of our load time is spent loading Distributions, and we don't even use its functionality. Perhaps we should instead offer a utility method convert(::Type{Distributions.Distribution}, ::PSIS.GeneralizedPareto) that is conditionally loaded when Distributions is in scope, so a user can easily get a Distribution if they want one while the package itself becomes even more lightweight.
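
A sketch of what the conditionally loaded method could look like as a package extension; the module name and the assumption that PSIS would ship its own lightweight GeneralizedPareto with fields μ, σ, ξ are mine (on older Julia versions Requires would play the same role):

module PSISDistributionsExt

using PSIS, Distributions

# Assumes a hypothetical PSIS-owned GeneralizedPareto with fields μ, σ, ξ.
function Base.convert(::Type{Distributions.Distribution}, d::PSIS.GeneralizedPareto)
    return Distributions.GeneralizedPareto(d.μ, d.σ, d.ξ)
end

end # module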

Similarly, ~18% of our load time is due to PrettyTables, but we only use a tiny amount of its functionality. Perhaps we should write our own function for pretty-printing PSISResult and drop PrettyTables as a dependency. It seems this would get us down to 5% of our current load time.

Add convenience function for computing expectation-specific diagnostics

The PSIS paper notes it is sometimes useful to compute h-specific (function/expectation-specific) diagnostics, which more specifically diagnose the quality of importance sampling for the expectation in question. From discussions with @avehtari, this can be done with the following procedure. Given importance ratios r(θ) and function evaluations h(θ), we:

  1. Compute w(θ) as r(θ) with a Pareto-smoothed upper tail.
  2. Compute v(θ) = h(θ) r(θ). Fit the GPD to upper tails of both v(θ) and -v(θ) to get kₕᵤ and kₕₗ. Compute kₕ=max(kₕᵤ, kₕₗ).
  3. Warn for either k > 0.7 or kₕ > 0.7
  4. Return w(θ), k, and kₕ. The user would then estimate 𝔼[h] as sum(h .* w) ./ sum(w).

Likewise, we can return h-specific ESS estimates by simply replacing w in the ESS computation with v.
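
A sketch of the procedure above; pareto_shape_of_tail is a hypothetical helper that only fits the GPD to an upper tail and returns k, and the property names on the result follow the results-object proposal rather than any existing API:

function hspecific_diagnostics(log_ratios, h)
    result = psis(log_ratios)                    # step 1: Pareto-smooth the upper tail of r(θ)
    w, k = result.weights, result.pareto_shape   # assumed property names
    r = exp.(log_ratios .- maximum(log_ratios))  # raw ratios on a numerically stable scale
    v = h .* r                                   # step 2: v(θ) = h(θ) r(θ)
    kh = max(pareto_shape_of_tail(v), pareto_shape_of_tail(-v))  # hypothetical helper
    (k > 0.7 || kh > 0.7) && @warn "Importance sampling may be unreliable" k kh  # step 3
    return (weights=w, pareto_shape=k, pareto_shape_h=kh)                        # step 4
end

# The user would then estimate 𝔼[h] as sum(h .* result.weights) / sum(result.weights).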

A few things to consider:

  1. Providing this requires decoupling diagnosing from smoothing, i.e. for some uses one might just want to fit the GPD and check the diagnostic without modifying any weights. (Edit: see #26)
  2. PSISResult currently can't accommodate 3 smoothing results. Perhaps it should be extended, or a different object should be returned if h is provided.
  3. Currently the API assumes users sample from the proposal distribution once and then smooth for many different target distributions simultaneously. While this is what is needed for LOO, I suspect that diagnosing for many different h's simultaneously could be more useful. Perhaps the API should support both.
  4. RE (2), since Monte Carlo is just a special case of IS where all weights are uniform, the same function could be used to diagnose all univariate marginals of a posterior sample to identify problems with estimating means and variances of parameters. This should be supported.
  5. The moment-matching paper (https://arxiv.org/abs/1906.08850; see #25) uses a similar formulation with v. Between this feature and that one, some changes may need to be made to the internals.

More useful printed result for a single parameter

Currently, if we use PSIS for a single parameter, we get the following output:

PSISResult with 1 parameters, 1000 draws, and 1 chains
Pareto shape (k) diagnostic values:
                    Count       Min. ESS 
 (-Inf, 0.5]  good  1 (100.0%)  363

The downside here is that the actual shape value is not shown, and the Count column is not informative. Probably something like this would be more useful:

PSISResult with 1 parameters, 1000 draws, and 1 chains
Pareto shape (k) diagnostic value:
                   Shape k  ESS 
good  (-Inf, 0.5]  0.3      363

OR

PSISResult with 1 parameters, 1000 draws, and 1 chains
    Shape diagnostic k: 0.3  (-Inf, 0.5]  (good)
    ESS: 363

Importance weighted moment matching

https://arxiv.org/abs/1906.08850 introduced an adaptive importance sampling method called Importance Weighted Moment Matching. Given a target distribution p, it takes Monte Carlo samples θ from a proposal distribution q, the log density function of q, and a function h whose expectation one wants to compute. It then alternates between transforming the Monte Carlo samples to θ* (effectively modifying the proposal distribution) and computing the Pareto shape diagnostic k, until k ≤ 0.7. The transformations used are affine and are chosen to match the first two moments of the proposal distribution to the first two importance-weighted moments.

The method returns either the estimated expectation 𝔼ₚ[h(θ)] or, perhaps, h(θ*) and w(θ*), as well as the shape diagnostic. It would also be handy to return the sequence of affine transformations (or their composition) used. Like #21, this motivates some rethinking of the API.
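
A heavily simplified, univariate sketch of the alternation. It recomputes importance ratios via a caller-supplied log_ratio_fn and skips the density/Jacobian bookkeeping the paper's algorithm actually requires, and it assumes a results object exposing weights and pareto_shape; treat it as pseudocode rather than the paper's method:

using Statistics

function iwmm_sketch(θ, log_ratio_fn, h; k_threshold=0.7, max_iters=20)
    transforms = Function[]
    for _ in 1:max_iters
        result = psis(log_ratio_fn(θ))
        result.pareto_shape ≤ k_threshold && break
        w = result.weights
        μw = sum(w .* θ) / sum(w)                     # importance-weighted mean
        σw = sqrt(sum(w .* (θ .- μw) .^ 2) / sum(w))  # importance-weighted std
        m, s = mean(θ), std(θ)
        affine = x -> μw + (σw / s) * (x - m)         # match the first two moments
        θ = affine.(θ)
        push!(transforms, affine)
    end
    result = psis(log_ratio_fn(θ))
    w = result.weights
    estimate = sum(h.(θ) .* w) / sum(w)
    return (estimate=estimate, pareto_shape=result.pareto_shape, draws=θ, transforms=transforms)
end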

TagBot trigger issue

This issue is used to trigger TagBot; feel free to unsubscribe.

If you haven't already, you should update your TagBot.yml to include issue comment triggers.
Please see this post on Discourse for instructions and more details.

If you'd like for me to do this for you, comment TagBot fix on this issue.
I'll open a PR within a few hours, please be patient!
