ledell / cvAUC

Computationally efficient confidence intervals for cross-validated AUC estimates in R

License: Apache License 2.0

Language: R (100%)
Topics: cross-validation, confidence-intervals, auc, machine-learning, statistics, r, variance

cvauc's Introduction

cvAUC

The cvAUC R package provides a computationally efficient means of estimating confidence intervals (or variance) of cross-validated Area Under the ROC Curve (AUC) estimates. This allows you to generate confidence intervals in seconds, compared to other techniques that are many orders of magnitude slower.

In binary classification problems, the AUC is commonly used to evaluate the performance of a prediction model. Often, it is combined with cross-validation in order to assess how the results will generalize to an independent data set. In order to evaluate the quality of an estimate for cross-validated AUC, we obtain an estimate of its variance.

For massive data sets, the process of generating a single performance estimate can be computationally expensive. Additionally, when using a complex prediction method, the process of cross-validating a predictive model on even a relatively small data set can still require a large amount of computation time. Thus, in many practical settings, the bootstrap is a computationally intractable approach to variance estimation. As an alternative to the bootstrap, a computationally efficient influence curve based approach to obtaining a variance estimate for cross-validated AUC can be used.
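
As a rough illustration of the idea (a sketch only, not the package's internal code, and ignoring ties for simplicity): the influence curve of the empirical AUC assigns each positive observation the fraction of negatives it outranks, and each negative observation the fraction of positives ranked above it, each centered at the AUC and scaled by the inverse class proportion; the variance estimate is then the sample variance of these values divided by the sample size. The function name below is hypothetical and assumes labels coded 0/1.

ic_auc_se <- function(predictions, labels) {
  # Illustrative sketch: influence curve based standard error for a plain
  # (non-cross-validated) AUC estimate; the cvAUC package handles the
  # cross-validated case and ties.
  pos <- predictions[labels == 1]
  neg <- predictions[labels == 0]
  p1 <- mean(labels == 1)
  p0 <- 1 - p1
  auc <- mean(outer(pos, neg, ">"))  # empirical AUC (ties ignored here)
  ic <- numeric(length(labels))
  # Positive cases: fraction of negatives ranked below, centered at the AUC
  ic[labels == 1] <- (1 / p1) * (vapply(pos, function(x) mean(neg < x), numeric(1)) - auc)
  # Negative cases: fraction of positives ranked above, centered at the AUC
  ic[labels == 0] <- (1 / p0) * (vapply(neg, function(x) mean(pos > x), numeric(1)) - auc)
  sqrt(var(ic) / length(labels))  # asymptotic standard error of the AUC
}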

The primary functions of the package are ci.cvAUC() and ci.pooled.cvAUC(), which report cross-validated AUC and compute confidence intervals for cross-validated AUC estimates based on influence curves for i.i.d. and pooled repeated measures data, respectively. One benefit to using influence curve based confidence intervals is that they require much less computation time than bootstrapping methods. The utility functions, AUC() and cvAUC(), are simple wrappers for functions from the ROCR package.
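
For a quick sense of the interface, here is a minimal sketch on simulated data. The data and fold assignment below are purely illustrative (in practice the predictions should be genuine cross-validated predictions); a full cross-validation demo follows in the "Using cvAUC" section.

library(cvAUC)

# Simulated held-out predictions and 0/1 labels, split into 5 illustrative folds
set.seed(1)
n <- 500
labels <- rbinom(n, size = 1, prob = 0.5)
predictions <- ifelse(labels == 1, rnorm(n, 0.6, 0.2), rnorm(n, 0.4, 0.2))
folds <- split(sample(seq_len(n)), rep(1:5, length.out = n))

AUC(predictions, labels)                   # plain AUC (wrapper around ROCR)
cvAUC(predictions, labels, folds = folds)  # per-fold AUCs and their average
ci.cvAUC(predictions, labels, folds = folds, confidence = 0.95)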

The methodology is described in: Erin LeDell, Maya L. Petersen and Mark J. van der Laan, "Computationally Efficient Confidence Intervals for Cross-validated Area Under the ROC Curve Estimates," Electronic Journal of Statistics, 2015.

Install cvAUC

You can install:

  • The latest released version from CRAN with:

    install.packages("cvAUC")
  • The latest development version from GitHub with:

    remotes::install_github("ledell/cvAUC")

Using cvAUC

Here is a demo of how you can use the package, along with some benchmarks of the speed of the method. For a simpler example that runs faster, you can check out the help files for the various functions inside the R package.

In this example of the ci.cvAUC() function, we do the following:

  • Load an i.i.d. data set with a binary outcome.

  • We will use 10-fold cross-validation, so we need to divide the indices randomly into 10 folds. In this step, we stratify the folds by the outcome variable. Stratification is not necessary, but is commonly performed in order to create validation folds with similar distributions. This information is stored in a 10-element list called folds. Below, the function that creates the folds is called .cvFolds.

  • For the vth iteration of the cross-validation (CV) process, fit a model on the training data (i.e., observations in folds {1,...,10} \ v) and then use this fit to generate predicted values for the observations in the vth validation fold. The .doFit() function below performs this step. In this example, we use the Random Forest algorithm.

  • Next, the .doFit() function is applied across all 10 folds to generate the predicted values for the observations in each validation fold.

  • These predicted values are stored in a vector called predictions, in the original order of the training observations.

  • Lastly, we use the ci.cvAUC() function to calculate CV AUC and to generate a 95% confidence interval for this CV AUC estimate.

First, we define a few utility functions:

.cvFolds <- function(Y, V){
  # Create V cross-validation folds, stratified by the binary outcome Y
  Y0 <- split(sample(which(Y == 0)), rep(1:V, length = length(which(Y == 0))))
  Y1 <- split(sample(which(Y == 1)), rep(1:V, length = length(which(Y == 1))))
  folds <- vector("list", length = V)
  for (v in seq(V)) {
    folds[[v]] <- c(Y0[[v]], Y1[[v]])
  }
  return(folds)
}

.doFit <- function(v, folds, train, y){
  # Train a randomForest on all folds except v; return the predicted
  # class-1 vote fractions for the observations in the vth validation fold
  set.seed(v)
  ycol <- which(names(train) == y)
  params <- list(x = train[-folds[[v]], -ycol],
                 y = as.factor(train[-folds[[v]], ycol]),
                 xtest = train[folds[[v]], -ycol])
  fit <- do.call(randomForest, params)
  pred <- fit$test$votes[, 2]
  return(pred)
}

This function will execute the example:

iid_example <- function(train, y = "response", V = 10, seed = 1) {
  
  # Create folds
  set.seed(seed)
  folds <- .cvFolds(Y = train[,c(y)], V = V)
  
  # Generate CV predicted values
  cl <- makeCluster(detectCores())
  registerDoParallel(cl)
  predictions <- foreach(v = 1:V, .combine = "c", 
    .packages = c("randomForest"),
    .export = c(".doFit")) %dopar% .doFit(v, folds, train, y)
  stopCluster(cl)
  # Re-order: foreach returns the predictions concatenated in fold order, so
  # map them back to the original row order of the training data
  predictions[unlist(folds)] <- predictions

  # Get CV AUC and 95% confidence interval
  runtime <- system.time(res <- ci.cvAUC(predictions = predictions, 
                                         labels = train[,c(y)],
                                         folds = folds, 
                                         confidence = 0.95))
  print(runtime)
  return(res)
}

Load a sample binary-outcome training set (10,000 rows) into R:

train_csv <- "https://erin-data.s3.amazonaws.com/higgs/higgs_train_10k.csv"
train <- read.csv(train_csv, header = TRUE, sep = ",")

Run the example:

library(randomForest)
library(doParallel)  # to speed up the model training in the example
library(cvAUC)

res <- iid_example(train = train, y = "response", V = 10, seed = 1)
#   user  system elapsed 
#  0.096   0.005   0.102 

print(res)
# $cvAUC
# [1] 0.7818224
# 
# $se
# [1] 0.004531916
# 
# $ci
# [1] 0.7729400 0.7907048
# 
# $confidence
# [1] 0.95

cvAUC Performance

For the example above (10,000 observations), it took ~0.1 seconds to calculate the cross-validated AUC and the influence curve based confidence intervals. This was benchmarked on a 3.1 GHz Intel Core i7 processor using cvAUC package version 1.1.3.

For bigger (i.i.d.) training sets, here are a few rough benchmarks:

  • 100,000 observations: ~0.4 seconds
  • 1 million observations: ~5.0 seconds

To try it on bigger datasets yourself, replace the 10k-row training CSV with either of these files:

train_csv <- "https://erin-data.s3.amazonaws.com/higgs/higgs_train_100k.csv"
train_csv <- "https://erin-data.s3.amazonaws.com/higgs/higgs_train_1M.csv"  

cvauc's People

Contributors: ledell

cvauc's Issues

add license file

According to the Apache 2.0 license, a license file must be shipped with the package but it currently has none. Can you please add one?

ci.cvAUC needs 0.5 for ties

Hi Erin,
The ROCR package's calculation of the AUC assigns 0.5 points for a tie. I was looking at your code for calculating the CIs, and saw that it ignores that possibility. Although people argue over strategies for dealing with ties, since the code is estimating the variance of the cv-AUC, as calculated by the ROCR package, it ought to respect the underlying calculation of the AUC.

DT[, ":="(icVal, ifelse(label == pos, w1 * (fracNegLabelsWithSmallerPreds - auc), w0 * (fracPosLabelsWithLargerPreds - auc)))]

For some positive observation, i, this line will assign w1 * 1 for each negative label earlier in the ordering, when those that are tied with i should arguably get w1 * 0.5. Also, there may be one or more negative-label observations immediately after i in the ordering (tied with i) that should be counted as 0.5 instead of 0. (Of course, similar logic applies to the negative-label calculations.)

--Susan Gruber
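
For reference, the mid-rank convention described in this issue (a tied positive/negative pair contributes 0.5 rather than 0 or 1) can be written as a small standalone sketch. This is illustrative only, not the package's code, and the function name is hypothetical:

# Illustrative tie-aware AUC: each tied positive/negative pair counts 0.5
tie_aware_auc <- function(predictions, labels) {
  pos <- predictions[labels == 1]
  neg <- predictions[labels == 0]
  mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))
}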

Repeated CV

Have you considered supporting repeated cross-validation, to further reduce the impact of random variation in folds? Seems pretty trivial to implement but is a good practice. It is a commonly used feature of Max Kuhn's caret package.
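
Repeated CV is not currently built into the package, but a rough sketch of the idea, reusing the iid_example() function defined above, is to re-run the fold assignment with different seeds and average the point estimates. How to combine the per-repeat confidence intervals is a separate question that this sketch does not address:

# Rough sketch: repeat the 10-fold CV with different fold seeds and average
# the CV AUC point estimates (confidence intervals are not combined here)
repeats <- lapply(1:5, function(s) {
  iid_example(train = train, y = "response", V = 10, seed = s)
})
mean(sapply(repeats, function(r) r$cvAUC))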

Observation weights?

Hi Erin,

We're interested in getting AUC confidence intervals for a SuperLearner project where we use observation weights. I didn't see an argument for that, but wanted to check whether you might have code somewhere, or any thoughts on using observation weights? Or maybe we can put a PR together if it isn't too tricky to implement.

Thanks,
Chris

Code for use with SuperLearner

Hi Erin,

I'm interested in using cvAUC with SuperLearner on a project with Alan, and I was wondering whether you happen to have any existing code for that, or advice. Am I right in thinking that one would need to use this with CV.SuperLearner to get the correct cross-validation folds for estimating the AUC CI of the final SuperLearner itself? I have skimmed your dissertation on cvAUC but need to read it more closely.

Appreciate it,
Chris
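
There is no SuperLearner-specific helper in cvAUC, but one possible pattern is sketched below. It assumes a fitted CV.SuperLearner object, cvsl, and that its SL.predict, Y, and folds components hold the cross-validated predictions, the observed outcomes, and the per-fold validation-row indices; check your SuperLearner version's documentation before relying on those field names:

library(SuperLearner)
library(cvAUC)

# Hedged sketch: CI for the cross-validated AUC of the SuperLearner itself.
# Assumes something like:
#   cvsl <- CV.SuperLearner(Y = Y, X = X, V = 10, family = binomial(),
#                           SL.library = c("SL.mean", "SL.glm"))
ci.cvAUC(predictions = cvsl$SL.predict,  # cross-validated SL predictions
         labels = cvsl$Y,                # observed binary outcomes
         folds = cvsl$folds,             # list of validation-row indices per fold
         confidence = 0.95)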

Submit v1.1.4 to CRAN

The 1.1.4 release (relative to 1.1.3) includes some minor metadata changes:

  • Move data.table and ROCR from Depends to Imports
  • Update maintainer email with CRAN from old berkeley.edu email to current email
  • Update NEWS file (maybe change/add NEWS.md)
  • Remove the startup note: "Notice to cvAUC users: Major speed improvements in version 1.1.0"
