
BEAT (Bias-Eliminating Adapted Trees)

Eva Ascarza and Ayelet Israeli

Forked from https://github.com/grf-labs/grf

Installation

This package is not available on CRAN and must be built from source. On Windows you will need RTools from https://cran.r-project.org/bin/windows/Rtools/ and on OS X you will need the developer tools documented in https://mac.r-project.org/tools/.

Please note that this package only runs on R 4.2.3 or earlier; see the installation instructions below.

First install the devtools package, together with a few packages used only in the example below, then use devtools to install beat:

install.packages(c("devtools", "ggplot2", "ggpubr", "data.table", "gridExtra")) ## all but devtools are only needed for the example below
devtools::install_version("RcppEigen", "0.3.3.7.0") ## beat does not work with newer RcppEigen
devtools::install_version("RcppArmadillo", "0.11.4.0.1") ## beat does not work with newer RcppArmadillo
devtools::install_github("ayeletis/beat")  ## do not update the RcppEigen or RcppArmadillo package if prompted
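
After installation, a quick way to check that the build succeeded is to load the package and confirm that the new functions are available (a minimal check; the function names are those listed in the next section):

library(beat)
exists("balanced_causal_forest")      # should return TRUE
exists("balanced_regression_forest")  # should return TRUE
exists("balanced_probability_forest") # should return TRUE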

Changes relative to GRF

The package offers three new functions: balanced_causal_forest, balanced_regression_forest, balanced_probability_forest.

All arguments are the same as in the original package, but there are two new inputs: target.weight.penalty indicates the penalty assigned to the protected attributes, and target.weights is a matrix that contains the protected characteristics. X should not include the protected characteristics.

See full details about the BEAT method in the original paper: Eliminating unintended bias in personalized policies using bias-eliminating adapted trees (BEAT)
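
For orientation, a minimal call sketch (X, Y, W, and Z are placeholder objects here; a complete worked example follows in the next section):

## X: matrix of unprotected features (must not contain the protected attributes)
## Y: outcome, W: treatment indicator, Z: protected attributes
fit <- balanced_causal_forest(X, Y, W,
                              target.weights = as.matrix(Z),
                              target.weight.penalty = 10)
tau_hat <- predict(fit)$predictions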

Sample Usage

library(beat)
library(data.table)
library(ggpubr)

rm(list=ls())


## ----------------------------------------------
##   Simulate some data
## ----------------------------------------------

n1 = 1000; #calibration 
n2 = 1000; #validation
p_continuous = 4  # number of continuous features (unprotected)
p_discrete = 3  # number of discrete features (unprotected)
p_demog = 1 # number of protected attributes
n = n1 + n2

# Features (unprotected)
X_cont = matrix(rnorm(n*p_continuous), n, p_continuous)
X_disc = matrix(rbinom(n*p_discrete, 1, 0.3),n,p_discrete)
X = cbind(X_cont,X_disc)

# Protected attribute (discrete), correlated with X_cont[,2]
Z = rbinom(n, 1, 1/(1+exp(-X_cont[,2])))

# Tau -- in this example it depends on X but not on Z
tau <- (-1 + pmax(X[,1], 0) + X[,2] + abs(X[,3]) + X[,5]) 

# Random assignment
W = rbinom(n, 1, 0.5)

# Output for regression forest (no treatment)
Y_r =  X[,1] - 2*X[,2] + X[,4] + 3*Z + runif(n)   # Y is function of X, Z(demo)

# Output for causal forest
Y =  Y_r + tau*W   # Y is function of X, Z(demo), tau*W

train_data = data.frame(Y=Y[c(1:n1)], 
                        Z=Z[c(1:n1)], 
                        W=W[c(1:n1)], 
                        X=X[c(1:n1),], 
                        tau = tau[c(1:n1)], 
                        Y_r = Y_r[c(1:n1)])
test_data = data.frame(Y=Y[c((n1+1):(n1+n2))],
                       Z=Z[c((n1+1):(n1+n2))],
                       W=W[c((n1+1):(n1+n2))],
                       X=X[c((n1+1):(n1+n2)),], 
                       tau = tau[c((n1+1):(n1+n2))],
                       Y_r = Y_r[c((n1+1):(n1+n2))])

Xcols = grep("X", names(train_data), value=TRUE)
Zcols =grep('Z', names(train_data), value=TRUE)
  
  
## train (columns 4:10 hold the seven X features)
X_train = train_data[, c(4:10)]
W_train = train_data$W
Z_train = train_data[,2]
Y_train = train_data$Y
Y_r_train = train_data$Y_r

## test (columns 4:10 hold the seven X features)
X_test = test_data[, c(4:10)]
Z_test = test_data$Z

## model specs
num_trees = 2000
my_penalty = 10 # When penalty = 0 it corresponds to GRF

## ----------------------------------------------
##   Estimate Balanced Causal Forest 
## ----------------------------------------------
fit_causal_beat <- balanced_causal_forest(X_train, Y_train, W_train,
                                     target.weights = as.matrix(Z_train),
                                     target.weight.penalty = my_penalty,
                                     num.trees = num_trees)
  

## Predict CBT causal scores
cbt_causal_train = predict(fit_causal_beat)$predictions
cbt_causal_test = predict(fit_causal_beat, X_test)$predictions
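
As noted above, setting the penalty to zero corresponds to GRF; a sketch of a comparison fit using the same training objects:

## Refit with target.weight.penalty = 0 (equivalent to standard GRF) for comparison
fit_causal_grf <- balanced_causal_forest(X_train, Y_train, W_train,
                                         target.weights = as.matrix(Z_train),
                                         target.weight.penalty = 0,
                                         num.trees = num_trees)
grf_causal_test = predict(fit_causal_grf, X_test)$predictions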


## ----------------------------------------------
##   Estimate Balanced Regression Forest 
## ----------------------------------------------
fit_regression_beat <- balanced_regression_forest(X_train, Y_r_train,
                                       target.weights = as.matrix(Z_train),
                                       target.weight.penalty = my_penalty,
                                       num.trees = num_trees)

## Predict CBT regression scores
cbt_regression_train = predict(fit_regression_beat)$predictions
cbt_regression_test = predict(fit_regression_beat, X_test)$predictions


## ----------------------------------------------
##   Check balance in test scores
## ----------------------------------------------
dat.plot = data.table(cbt_causal = cbt_causal_test,
                      cbt_regr = cbt_regression_test,
                      true_causal = test_data$tau,
                      true_reg = test_data$Y_r,
                      Z = as.factor(Z_test))


p1 = ggdensity(data=dat.plot,
                   x='true_causal', color='Z', fill='Z', alpha=0.2, add = "mean", title='true causal')

p2 = ggdensity(data=dat.plot,
                   x='cbt_causal', color='Z', fill='Z', alpha=0.2, add = "mean", title='cbt causal')

p3 = ggdensity(data=dat.plot,
          x='true_reg', color='Z', fill='Z', alpha=0.2, add = "mean", title='true regression')

p4 = ggdensity(data=dat.plot,
          x='cbt_regr', color='Z', fill='Z', alpha=0.2, add = "mean", title='cbt regression')

require(gridExtra)
grid.arrange(p1, p2, p3, p4, ncol=2)
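
Beyond the density plots, a simple numeric check (a sketch using the data.table built above) is to compare group means of the scores across the protected attribute:

## Mean scores by protected group; the CBT scores should differ less across Z than the true scores
dat.plot[, .(mean_true_causal = mean(true_causal),
             mean_cbt_causal  = mean(cbt_causal),
             mean_true_reg    = mean(true_reg),
             mean_cbt_reg     = mean(cbt_regr)), by = Z]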


 


Issues

beat should only include the 3 new functions

We do not want to override the existing grf functions. Currently, installing beat also makes all of the grf functions available; we do not want that. Instead, installing beat should make only three functions available to users:
balanced_causal_forest, balanced_regression_forest, and balanced_probability_forest.
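
A sketch of what the intended exports could look like in the package NAMESPACE file (standard R packaging; only the three function names come from this issue):

# NAMESPACE
export(balanced_causal_forest)
export(balanced_regression_forest)
export(balanced_probability_forest)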

compilation failed

Trying to install beat on my Mac laptop, using the required versions of the packages and of R.
Got the following message:

...
In file included from RcppExports.cpp:5:
In file included from /Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppEigen/include/RcppEigen.h:25:
In file included from /Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppEigen/include/RcppEigenForward.h:37:
In file included from /Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppEigen/include/unsupported/Eigen/SparseExtra:51:
/Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppEigen/include/Eigen/src/Core/util/ReenableStupidWarnings.h:14:30: warning: pragma diagnostic pop could not pop, no matching push [-Wunknown-pragmas]
#pragma clang diagnostic pop
^
9 warnings and 1 error generated.
make: *** [src/forest/ForestPredictors.o] Error 1
21 warnings generated.
ERROR: compilation failed for package ‘beat’

* removing ‘/Library/Frameworks/R.framework/Versions/4.2/Resources/library/beat’
Warning message:
In i.p(...) :
  installation of package ‘/var/folders/bb/ntfl62j95g5bm2lf2q2jlf2h0000gp/T//Rtmp6jhEgf/file31397bcbdee6/beat_1.2.4.tar.gz’ had non-zero exit status

remove penalty metrics

The argument "target.weight.penalty.metric" should be eliminated.
Instead, we would like to just use the default "split_l2_norm_rate".

the three effected functions are: balanced_causal_forest, balanced_regression_forest, and balanced_probability_forest.

also remove this section from the Readme.
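
For illustration, a sketch of the call before and after the change (the argument name and the default value are taken from this issue; the other arguments follow the sample usage above):

## current: the metric is exposed as an argument
fit <- balanced_causal_forest(X_train, Y_train, W_train,
                              target.weights = as.matrix(Z_train),
                              target.weight.penalty = my_penalty,
                              target.weight.penalty.metric = "split_l2_norm_rate",
                              num.trees = num_trees)

## intended: the argument is removed and "split_l2_norm_rate" is used internally
fit <- balanced_causal_forest(X_train, Y_train, W_train,
                              target.weights = as.matrix(Z_train),
                              target.weight.penalty = my_penalty,
                              num.trees = num_trees)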

Change the name

The name of the package should be changed from grf to beat.
