
michael-franke / faintr

R package 'faintr' for interpretation of BRMS model fits for data from factorial design experiment

Home Page: https://michael-franke.github.io/faintr/index.html

License: Other

Topics: bayesian, brms, factorial-design, r-package, regression, rstats, stan, contrast-coding

faintr's Introduction

Overview

The faintr (FActorINTerpreteR) package provides convenience functions for interpreting brms model fits for data from factorial designs. It allows for the extraction and comparison of posterior draws for a given design cell, irrespective of the encoding scheme used in the model.

Currently, faintr provides the following functions:

  • get_cell_definitions returns information on the predictor variables and how they are encoded in the model.
  • extract_cell_draws returns posterior draws and additional metadata for all design cells.
  • filter_cell_draws returns posterior draws and additional metadata for one subset of design cells.
  • compare_groups returns summary statistics for the comparison of two subsets of design cells.

Installation

You can install the development version from GitHub with:

# install.packages("devtools")
devtools::install_github("michael-franke/faintr")

Examples

In this section, we briefly introduce how to use the package. For a more detailed overview, please refer to the vignette.

We will use a preprocessed version of the mouse-tracking data set from the aida package:

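library(faintr)
library(brms)       # brm(), prior()
library(tidyverse)  # %>%, select(), pivot_longer(), ggplot()

# `data` holds the preprocessed mouse-tracking data from the aida package
data %>% 
  select(RT, group, condition, prototype_label) %>%
  head()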
#> # A tibble: 6 x 4
#>      RT group condition prototype_label
#>   <dbl> <chr> <chr>     <fct>          
#> 1   950 touch Atypical  straight       
#> 2  1251 touch Typical   straight       
#> 3   930 touch Atypical  curved         
#> 4   690 touch Atypical  curved         
#> 5   951 touch Typical   CoM            
#> 6  1079 touch Atypical  CoM

The variables relevant for us are:

  • RT: Reaction time in milliseconds
  • group: Whether a category is selected by click vs touch
  • condition: Whether the animal is a typical vs atypical representative of its category
  • prototype_label: The type of prototypical movement strategy (straight vs curved vs CoM)

Below, we regress the log-transformed reaction times as a function of factors group, condition, prototype_label, and their three-way interaction using a linear regression model fitted with brms:

fit <- brms::brm(formula = log(RT) ~ group * condition * prototype_label,
                 data = data,
                 seed = 123
                 )

To obtain information on the factors and the coding scheme used in the model, we can use get_cell_definitions:

get_cell_definitions(fit)
#> # A tibble: 12 x 16
#>     cell group condition prototype_label Intercept grouptouch conditionTypical
#>    <int> <chr> <chr>     <fct>               <dbl>      <dbl>            <dbl>
#>  1     1 touch Atypical  straight                1          1                0
#>  2     2 touch Typical   straight                1          1                1
#>  3     3 touch Atypical  curved                  1          1                0
#>  4     4 touch Typical   CoM                     1          1                1
#>  5     5 touch Atypical  CoM                     1          1                0
#>  6     6 touch Typical   curved                  1          1                1
#>  7     7 click Atypical  straight                1          0                0
#>  8     8 click Typical   straight                1          0                1
#>  9     9 click Typical   curved                  1          0                1
#> 10    10 click Atypical  CoM                     1          0                0
#> 11    11 click Typical   CoM                     1          0                1
#> 12    12 click Atypical  curved                  1          0                0
#> # ... with 9 more variables: prototype_labelcurved <dbl>,
#> #   prototype_labelCoM <dbl>, `grouptouch:conditionTypical` <dbl>,
#> #   `grouptouch:prototype_labelcurved` <dbl>,
#> #   `grouptouch:prototype_labelCoM` <dbl>,
#> #   `conditionTypical:prototype_labelcurved` <dbl>,
#> #   `conditionTypical:prototype_labelCoM` <dbl>,
#> #   `grouptouch:conditionTypical:prototype_labelcurved` <dbl>, ...

The output shows that factors group, condition and prototype_label are dummy-coded, with click, Atypical, and straight being the reference levels, respectively.
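Each design cell corresponds to one row of this coding table, so a cell's predicted value is that row multiplied by the coefficient vector, draw by draw. This is what allows cell draws to be computed for any coding scheme. The base-R sketch below illustrates the idea under our own assumptions: a hypothetical 2 × 2 design with factors A and B, simulated normal draws standing in for brms posterior draws, and R's default treatment coding. It is not faintr's internal code.

```r
# Sketch: posterior draws per design cell = coefficient draws %*% t(cell codings)
# Hypothetical 2 x 2 design; the factor names A and B are illustrative only
set.seed(1)
design <- expand.grid(A = c("a1", "a2"), B = c("b1", "b2"))  # strings become factors

# Treatment (dummy) coding: one row per cell, one column per coefficient
X <- model.matrix(~ A * B, data = design)

# Simulated stand-in for 4000 posterior draws of the four coefficients
n_draws <- 4000
beta <- cbind(
  rnorm(n_draws, 7.5, 0.05),   # (Intercept): reference cell a1:b1
  rnorm(n_draws, -0.2, 0.05),  # Aa2
  rnorm(n_draws, 0.1, 0.05),   # Bb2
  rnorm(n_draws, 0.05, 0.05)   # Aa2:Bb2
)
colnames(beta) <- colnames(X)

# (draws x coefficients) %*% (coefficients x cells) = (draws x cells)
cell_draws <- beta %*% t(X)
colnames(cell_draws) <- paste(design$A, design$B, sep = ":")
```

Switching to another coding (e.g., via the contrasts.arg argument of model.matrix()) changes X and the meaning of beta, but the same multiplication still yields one column of draws per cell.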

To extract posterior draws for all design cells, we can use extract_cell_draws:

extract_cell_draws(fit)
#> # A draws_df: 1000 iterations, 4 chains, and 12 variables
#>    touch:Atypical:straight touch:Typical:straight touch:Atypical:curved
#> 1                      7.4                    7.2                   7.5
#> 2                      7.4                    7.2                   7.4
#> 3                      7.4                    7.2                   7.5
#> 4                      7.4                    7.1                   7.4
#> 5                      7.4                    7.2                   7.6
#> 6                      7.4                    7.2                   7.4
#> 7                      7.4                    7.2                   7.5
#> 8                      7.4                    7.2                   7.5
#> 9                      7.4                    7.2                   7.4
#> 10                     7.4                    7.2                   7.5
#>    touch:Typical:CoM touch:Atypical:CoM touch:Typical:curved
#> 1                7.6                7.6                  7.2
#> 2                7.5                7.6                  7.1
#> 3                7.5                7.6                  7.1
#> 4                7.4                7.7                  7.1
#> 5                7.5                7.4                  7.1
#> 6                7.5                7.5                  7.1
#> 7                7.4                7.7                  7.2
#> 8                7.4                7.6                  7.1
#> 9                7.5                7.7                  7.2
#> 10               7.5                7.6                  7.2
#>    click:Atypical:straight click:Typical:straight
#> 1                      7.6                    7.4
#> 2                      7.6                    7.4
#> 3                      7.6                    7.4
#> 4                      7.6                    7.4
#> 5                      7.6                    7.4
#> 6                      7.6                    7.4
#> 7                      7.7                    7.4
#> 8                      7.7                    7.4
#> 9                      7.6                    7.4
#> 10                     7.6                    7.4
#> # ... with 3990 more draws, and 4 more variables
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}

With filter_cell_draws we can obtain posterior draws for a specific subset of design cells. For instance, draws for typical exemplars in click trials, averaged over factor prototype_label, can be extracted like so:

filter_cell_draws(fit, condition == "Typical" & group == "click")
#> # A draws_df: 1000 iterations, 4 chains, and 1 variables
#>    draws
#> 1    7.4
#> 2    7.4
#> 3    7.5
#> 4    7.5
#> 5    7.5
#> 6    7.4
#> 7    7.5
#> 8    7.4
#> 9    7.4
#> 10   7.5
#> # ... with 3990 more draws
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}
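One plausible reading of "averaged over factor prototype_label" is an unweighted, draw-by-draw mean over all cells that satisfy the filter. The sketch below assumes exactly that, using simulated draws for a hypothetical condition × prototype_label grid; the equal-weight averaging is our assumption.

```r
# Sketch: average the draws of all cells matching a filter, draw by draw
set.seed(2)
cells <- expand.grid(
  condition = c("Atypical", "Typical"),
  prototype_label = c("straight", "curved", "CoM"),
  stringsAsFactors = FALSE
)

# Simulated cell draws: rows = 4000 posterior draws, columns = the 6 cells above
cell_draws <- matrix(rnorm(4000 * nrow(cells), mean = 7.4, sd = 0.05),
                     ncol = nrow(cells))

# Filter cells, then take the row-wise (per-draw) mean over the selected columns
keep <- cells$condition == "Typical"
typical_draws <- rowMeans(cell_draws[, keep, drop = FALSE])
```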

The parameter colname sets the name of the output column (draws by default), which facilitates post-processing of cell draws, e.g., for plotting or computing summary statistics. Here, we extract the draws for each level of prototype_label (averaged over group and condition) and visualize the results:

draws_straight <- filter_cell_draws(fit, prototype_label == "straight", colname = "straight")
draws_curved <- filter_cell_draws(fit, prototype_label == "curved", colname = "curved")
draws_CoM <- filter_cell_draws(fit, prototype_label == "CoM", colname = "CoM")

draws_prototype <- posterior::bind_draws(draws_straight, draws_curved, draws_CoM) %>%
  pivot_longer(cols = posterior::variables(.), names_to = "prototype", values_to = "value")

draws_prototype %>%
  ggplot(aes(x = value, color = prototype, fill = prototype)) +
  geom_density(alpha = 0.4)

Finally, we can compare two subsets of design cells with compare_groups. Here, we compare the estimates for atypical exemplars in click trials against typical exemplars in click trials (averaged over the three prototypical movement strategies):

compare_groups(fit,
               higher = condition == "Atypical" & group == "click",
               lower = condition == "Typical" & group == "click"
               )
#> Outcome of comparing groups: 
#>  * higher:  condition == "Atypical" & group == "click" 
#>  * lower:   condition == "Typical" & group == "click" 
#> Mean 'higher - lower':  0.2215 
#> 95% HDI:  [ 0.1421 ; 0.2978 ]
#> P('higher - lower' > 0):  1 
#> Posterior odds:  Inf
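All of the statistics printed above can be derived from the vector of draw-wise differences 'higher - lower'. A minimal base-R sketch, assuming the HDI is the shortest interval containing 95% of the draws and the posterior odds are P / (1 - P); summarize_comparison is a hypothetical helper, not part of faintr.

```r
# Hypothetical helper: summarise 'higher - lower' from two equal-length draw vectors
summarize_comparison <- function(higher, lower, hdi_prob = 0.95) {
  delta <- higher - lower                 # computed once, reused below
  sorted <- sort(delta)
  n <- length(sorted)
  k <- ceiling(hdi_prob * n)              # number of draws inside the interval
  starts <- seq_len(n - k + 1)
  spans <- sorted[starts + k - 1] - sorted[starts]
  i <- which.min(spans)                   # shortest such interval = HDI
  p <- mean(delta > 0)
  data.frame(
    mean_diff = mean(delta),
    l_hdi = sorted[i],
    u_hdi = sorted[i + k - 1],
    post_prob = p,
    post_odds = p / (1 - p)               # Inf when every draw has delta > 0
  )
}

# Usage with simulated draws standing in for two filtered cell subsets
set.seed(3)
summarize_comparison(rnorm(4000, 7.6, 0.03), rnorm(4000, 7.4, 0.03))
```

A posterior probability of exactly 1 (no draw with a non-positive difference) is what produces posterior odds of Inf in the output above.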

If one of the two group specifications is omitted, the other group is compared against the grand mean:

compare_groups(fit,
               higher = group == "click"
               )
#> Outcome of comparing groups: 
#>  * higher:  group == "click" 
#>  * lower:   grand mean 
#> Mean 'higher - lower':  0.1009 
#> 95% HDI:  [ 0.06956 ; 0.1302 ]
#> P('higher - lower' > 0):  1 
#> Posterior odds:  Inf
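Assuming the grand mean is the unweighted average over all 12 cells, a comparison against it has a simple interpretation for a two-level factor such as group: since each level covers 6 cells, the 'click - grand mean' difference is exactly half the 'click - touch' difference. The sketch below checks this identity on simulated cell draws; the equal weighting is our assumption.

```r
# Sketch: comparing one level of a balanced two-level factor with the grand mean
set.seed(5)
group <- rep(c("touch", "click"), each = 6)   # 12 cells, 6 per group level

# Simulated cell draws: rows = 4000 posterior draws, columns = 12 cells
cell_draws <- matrix(rnorm(4000 * 12, mean = 7.4, sd = 0.1), ncol = 12)

click <- rowMeans(cell_draws[, group == "click"])
touch <- rowMeans(cell_draws[, group == "touch"])
grand_mean <- rowMeans(cell_draws)            # unweighted mean over all cells

vs_grand_mean <- click - grand_mean
```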

If the Boolean flag include_bf is set to TRUE (default: FALSE), Bayes factors for the inequality (higher > lower) are approximated against the “negated hypothesis” (higher <= lower). However, this requires specifying proper priors for all parameters:

fit_with_priors <- brms::brm(formula = log(RT) ~ group * condition * prototype_label,
                             prior = brms::prior(student_t(1, 0, 3), class = "b"),
                             data = data,
                             seed = 123
                             )
compare_groups(fit_with_priors,
               higher = prototype_label != "straight",
               lower = prototype_label == "straight",
               include_bf = TRUE
               )
#> Outcome of comparing groups: 
#>  * higher:  prototype_label != "straight" 
#>  * lower:   prototype_label == "straight" 
#> Mean 'higher - lower':  0.1062 
#> 95% HDI:  [ 0.05464 ; 0.1547 ]
#> P('higher - lower' > 0):  0.9998 
#> Posterior odds:  3999 
#> Bayes factor:  4015
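For an inequality such as 'higher > lower', one standard way to approximate the Bayes factor against its complement ('higher <= lower') is to divide the posterior odds of the inequality by its prior odds, both estimated from draws. This would also explain why proper priors are needed: the prior odds can only be estimated from draws of a proper prior. A minimal sketch; inequality_bf is a hypothetical helper, the prior and posterior draws of the difference are simulated, and whether faintr uses exactly this estimator is an assumption.

```r
# Hypothetical helper: BF for 'delta > 0' vs 'delta <= 0' as posterior odds / prior odds
inequality_bf <- function(post_delta, prior_delta) {
  post_p  <- mean(post_delta > 0)
  prior_p <- mean(prior_delta > 0)
  (post_p / (1 - post_p)) / (prior_p / (1 - prior_p))
}

# Usage with simulated draws of the difference 'higher - lower'
set.seed(6)
prior_delta <- 3 * rt(4000, df = 1)      # stand-in prior: student_t(1, 0, 3)
post_delta  <- rnorm(4000, 0.1, 0.05)    # stand-in for posterior draws
inequality_bf(post_delta, prior_delta)
```

With a prior that is symmetric around zero, the prior odds are close to 1, so the Bayes factor ends up close to the posterior odds.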

faintr's People

Contributors

michael-franke, ooezenoglu

Forkers

n-kall

faintr's Issues

`extract_cell_draws`: Output in long format

Currently, extract_cell_draws returns the draws in wide format:

> extract_cell_draws(fit)

# A draws_df: 1000 iterations, 4 chains, and 4 variables
   F:pol F:inf M:pol M:inf
1    236   264   156   153
2    235   269   135   148
3    235   265   151   168
4    197   226   146   168
5    211   229   169   175
6    164   208   156   170
7    258   273   111   142
8    286   298   180   192
9    215   251   214   222
10   231   263   189   206
# ... with 3990 more draws
# ... hidden reserved variables {'.chain', '.iteration', '.draw'}

I think it would be useful to also have an option to get the draws in long format (e.g., for easier plotting or computing summary statistics):

# A draws_df: 1000 iterations, 4 chains, and 3 variables
   gender context draw
1       F     pol   236
2       F     pol   235
3       F     pol   235
4       F     pol   197
5       F     pol   211
6       F     pol   164
7       F     pol   258
8       F     pol   286
9       F     pol   215
10      F     pol   231
# ... with 15990 more draws
# ... hidden reserved variables {'.chain', '.iteration', '.draw'}

We could add a parameter format (with "long" as the default?).
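For reference, the requested reshaping can already be done on the wide output. A base-R sketch, assuming cell columns are named '<gender>:<context>' as in the wide output above; the column names gender, context and draw follow the proposed output, and the toy values are simulated.

```r
# Sketch: reshape wide cell draws ("<gender>:<context>" columns) into long format
set.seed(4)
wide <- data.frame(
  `F:pol` = rnorm(5, 230, 20), `F:inf` = rnorm(5, 260, 20),
  `M:pol` = rnorm(5, 160, 20), `M:inf` = rnorm(5, 170, 20),
  check.names = FALSE
)

# Stack the columns: one row per (draw, cell) pair, column by column
long <- data.frame(
  cell = rep(names(wide), each = nrow(wide)),
  draw = unlist(wide, use.names = FALSE),
  stringsAsFactors = FALSE
)

# Split each cell label into one column per factor
parts <- strsplit(long$cell, ":", fixed = TRUE)
long$gender  <- vapply(parts, `[`, character(1), 1)
long$context <- vapply(parts, `[`, character(1), 2)
long <- long[, c("gender", "context", "draw")]
```

With the tidyverse, tidyr::pivot_longer(cols = everything(), names_to = c("gender", "context"), names_sep = ":", values_to = "draw") should give the same layout for a plain data frame.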

[feature] support rstanarm models

I think we can probably also support rstanarm models. We can extract the design matrix with rstanarm::get_x().
I'll look into this further, but I don't think faintr relies on anything in brms that is missing from rstanarm.

Possibly add easy way to extract draws for ALL cells w/o manual specification

I might be overlooking something, but I don't think that there is a simple way to obtain draws for ALL cells, while this is probably a rather common use case (e.g., for plotting / global inspection). As far as I can see, the user would have to call "extract_cell_draws" for every single cell, despite the fact that "extract_cell_draws" internally computes draws for ALL cells on each call. That's also rather inefficient.

I suggest that we include an option for "extract_cell_draws" (e.g., parameter "group = 'all'") to return the full list. Alternatively, we could make that the default and return the draws for the grand mean via something like "group = 'grand_mean'". I'd imagine that users want ALL samples more often than samples for the grand mean, so the former default would be better for the user.

Showcase return type of "compare_groups" in docs

The README and the vignette only show the printed display of the object returned by "compare_groups". It would likely be helpful to show how individual summary statistics can be extracted from the returned object (lest readers think that they have to parse the relevant information out of the printed string, manually or programmatically).

Restructure `faintCompare` object such that it's easier to handle

Without further wrangling, the current structure of the faintCompare object only allows extracting one summary statistic at a time. I think it would be more convenient if we stored the values in a data frame so that extraction can be done in one go. Also, the object does not contain the posterior samples used for comparison. While not strictly necessary for the purpose of the function, we might still want to include them to save additional calls to extract_cell_draws in case the samples are actually needed (e.g., for plotting or testing).

The object could have this structure:

> x <- compare_groups(fit, gender == "F", gender == "M")
> x %>% str()
List of 3
 $ hdi       : num 0.95
 $ comparison:'data.frame':	1 obs. of  7 variables:
  ..$ higher   : chr "gender == \"F\""
  ..$ lower    : chr "gender == \"M\""
  ..$ mean_diff: num 108
  ..$ l_ci     : num 95.4
  ..$ u_ci     : num 122
  ..$ post_prob: num 1
  ..$ post_odds: num Inf
 $ samples   :List of 2
  ..$ higher: draws_df [400 x 4] (S3: draws_df/draws/tbl_df/tbl/data.frame)
  .. ..$ draws     : num [1:400] 251 248 246 249 245 ...
  .. ..$ .chain    : int [1:400] 1 1 1 1 1 1 1 1 1 1 ...
  .. ..$ .iteration: int [1:400] 1 2 3 4 5 6 7 8 9 10 ...
  .. ..$ .draw     : int [1:400] 1 2 3 4 5 6 7 8 9 10 ...
  ..$ lower : draws_df [400 x 4] (S3: draws_df/draws/tbl_df/tbl/data.frame)
  .. ..$ draws     : num [1:400] 134 150 132 143 143 ...
  .. ..$ .chain    : int [1:400] 1 1 1 1 1 1 1 1 1 1 ...
  .. ..$ .iteration: int [1:400] 1 2 3 4 5 6 7 8 9 10 ...
  .. ..$ .draw     : int [1:400] 1 2 3 4 5 6 7 8 9 10 ...
 - attr(*, "class")= chr "faintCompare"

Reuse calculations in `compare_groups`

Currently, computing the summary statistics in compare_groups is quite repetitive. Reusing previous calculations would be more efficient here, especially for a large number of draws.

Possibly use arrays instead of data frames

I noticed that some of the code could be simplified if we use arrays / matrices instead of first converting things to data frames. It's also possible we could use draws_array or draws_matrix instead of draws_df to simplify some of the multiplication operations.

If you think this is worth changing and don't see any downsides, I can make a PR.

e.g.

cell_defs <- dplyr::bind_cols(
    fit$data %>% dplyr::select(dplyr::all_of(fixef)),
    as.data.frame(brms::standata(fit)$X)
  ) %>% unique()

can be simplified to

cell_defs <- dplyr::bind_cols(
    fit$data %>% dplyr::select(dplyr::all_of(fixef)),
    brms::standata(fit)$X
  ) %>% unique()

or without dplyr:

cell_defs <- fit$data[fixef] %>%
    cbind(brms::standata(fit)$X) %>%
    unique()

Add more unit tests

  • test get_cell_definitions, filter_cell_draws, and compare_groups more thoroughly
  • add tests for extract_cell_draws
