Causal Effect Inference for Structured Treatments

Overview

We address the estimation of conditional average treatment effects (CATEs) for structured treatments (e.g., graphs, images, texts). Given a weak condition on the effect, we propose the generalized Robinson decomposition, which (i) isolates the causal estimand (reducing regularization bias), (ii) allows one to plug in arbitrary models for learning, and (iii) possesses a quasi-oracle convergence guarantee under mild assumptions. In experiments with small-world and molecular graphs we demonstrate that our approach outperforms prior work in CATE estimation.
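For background, the classical Robinson decomposition for a binary treatment is sketched below. This is the standard textbook form under unconfoundedness, not the generalized version proposed in the paper; the symbols m, e, and tau are notation assumed here for illustration and do not come from this README.

```latex
% Classical Robinson decomposition (binary treatment T, covariates X, outcome Y).
% Nuisance functions: conditional mean outcome m and propensity e.
m(x) = \mathbb{E}[Y \mid X = x], \qquad e(x) = \mathbb{E}[T \mid X = x]

% The CATE \tau(x) is isolated in a residual-on-residual form:
Y - m(X) = \bigl(T - e(X)\bigr)\,\tau(X) + \varepsilon,
\qquad \mathbb{E}[\varepsilon \mid X, T] = 0
```

Because tau appears only in the residualized equation, m and e can be fit with arbitrary models first, which is the property the overview refers to as plugging in arbitrary models for learning.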

Link to paper

Requirements

We tested the implementation in Python 3.8.

Dependencies

requirements.txt is an automatically generated file with all dependencies.

Essential packages include:

rdkit
numpy
networkx
scikit-learn
torch
pyg
wandb
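
A quick way to see which of these packages are missing from the current environment is sketched below. The import names are assumptions based on how these libraries are usually imported (for example, scikit-learn as sklearn and pyg as torch_geometric); the helper missing_packages is hypothetical and not part of this repository.

```python
from importlib.util import find_spec

# Package name as listed above -> module name assumed for importing it.
ESSENTIAL_MODULES = {
    "rdkit": "rdkit",
    "numpy": "numpy",
    "networkx": "networkx",
    "scikit-learn": "sklearn",
    "torch": "torch",
    "pyg": "torch_geometric",
    "wandb": "wandb",
}


def missing_packages(modules=ESSENTIAL_MODULES):
    """Return the listed package names whose module cannot be found."""
    return [pkg for pkg, mod in modules.items() if find_spec(mod) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("Missing packages:", ", ".join(missing) if missing else "none")
```

The check only looks up whether each module can be located; it does not import the packages or verify their versions.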

Datasets

The TCGA simulation requires the TCGA and QM9 datasets. The code automatically downloads and unzips these datasets if they do not exist. Alternatively, the TCGA dataset can be downloaded from here and the QM9 dataset from here. Both datasets should be located in data/tcga/.
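
When placing the datasets manually, a small check such as the one below can confirm that data/tcga/ exists and is not empty before starting a run. This is a minimal sketch: the helper name dataset_dir_populated is hypothetical, and the check does not know which file names the repository expects inside the folder.

```python
from pathlib import Path


def dataset_dir_populated(path="data/tcga"):
    """Return True if `path` is an existing directory containing at least one entry."""
    directory = Path(path)
    return directory.is_dir() and any(directory.iterdir())


if __name__ == "__main__":
    if dataset_dir_populated():
        print("data/tcga/ contains files.")
    else:
        print("data/tcga/ is missing or empty.")
```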

Entry points

There are four runnable Python scripts:

  • generate_data.py: Generates and saves a dataset given the configuration in configs/generate_data/.
    • Stores generated data in data_path with folder structure {data_path}/{task}/seed-{seed}/bias-{bias}/
    • For each task, seed, and bias combination, generates and stores a new dataset
  • run_model_training.py: Trains and evaluates a CATE estimation model given the configuration in configs/run_model/.
    • Evaluation results are logged and can optionally be saved to results_path and/or synced to a wandb.ai account
  • run_hyperparameter_sweeping.py: Sweeps hyper-parameters with wandb as specified in configs/sweeps/
  • run_unseen_treatment_update.py: Runs the GNN baseline on a specified dataset and updates the one-hot encodings of previously unseen treatments in the test set to the closest ones seen during training, based on their Euclidean distance in the hidden embedding space.
    • Run this script before running the CAT baseline. Otherwise, one-hot encodings of unseen treatments will be fed into the network.
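
The idea behind the unseen-treatment update, mapping each unseen treatment to its nearest seen treatment in embedding space, can be illustrated as follows. This is a minimal sketch and not the repository's implementation: the function map_unseen_to_seen, the dictionary layout, and the toy embeddings are all assumptions made for illustration.

```python
import math


def map_unseen_to_seen(seen, unseen):
    """Map each unseen treatment id to the id of the closest seen treatment.

    `seen` and `unseen` are dicts from treatment id to embedding vector.
    Closeness is measured with the Euclidean distance (math.dist, Python 3.8+).
    """
    return {
        unseen_id: min(seen, key=lambda seen_id: math.dist(seen[seen_id], embedding))
        for unseen_id, embedding in unseen.items()
    }


# Toy embeddings (hypothetical values).
seen_embeddings = {"A": [0.0, 0.0], "B": [1.0, 1.0]}
unseen_embeddings = {"C": [0.9, 0.8], "D": [0.1, -0.2]}

mapping = map_unseen_to_seen(seen_embeddings, unseen_embeddings)
print(mapping)  # {'C': 'B', 'D': 'A'}
```

After such a mapping, an unseen treatment would reuse the one-hot encoding of its assigned seen treatment, so a model restricted to treatments seen during training never receives an encoding it was not trained on.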

Quick tour

generate_data.py

Important arguments

  • task: Simulation task, either sw or tcga
  • bias: Treatment selection bias coefficient
  • seed: Random seed
  • data_path: Path to save/load generated datasets
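
These arguments determine where a dataset is stored, following the folder structure {data_path}/{task}/seed-{seed}/bias-{bias}/ listed under Entry points. The sketch below builds such paths; the helper names are hypothetical, and how the repository formats numeric values such as the bias inside the folder name is an assumption.

```python
from itertools import product
from pathlib import Path


def dataset_dir(data_path, task, seed, bias):
    """Build {data_path}/{task}/seed-{seed}/bias-{bias} as a Path."""
    return Path(data_path) / task / f"seed-{seed}" / f"bias-{bias}"


def all_dataset_dirs(data_path, tasks, seeds, biases):
    """One directory per (task, seed, bias) combination."""
    return [
        dataset_dir(data_path, task, seed, bias)
        for task, seed, bias in product(tasks, seeds, biases)
    ]


if __name__ == "__main__":
    # Hypothetical values for illustration only.
    for path in all_dataset_dirs("data", ["sw", "tcga"], [0, 1], [0.1]):
        print(path.as_posix())
```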

run_model_training.py

Important arguments

  • task: Simulation task, either sw or tcga
  • model: SIN, gnn, cat, graphite, zero
  • bias: Treatment selection bias coefficient
  • seed: Random seed

Remarks

TCGA Simulation warnings

When parsing SMILES strings from the QM9 dataset to simulate a TCGA experiment, there may be bad-input warnings for certain molecules. The data generator ignores these molecules. When subsampling 10k molecules, we observed that roughly 1% of them are faulty.
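
The skipping behaviour can be illustrated with RDKit, whose Chem.MolFromSmiles returns None when a string cannot be parsed. This is a minimal sketch rather than the data generator's actual code: the helper filter_valid_smiles and its parse argument are hypothetical, and RDKit is imported lazily so the helper can also be exercised with a stand-in parser.

```python
def _rdkit_parse(smiles):
    """Parse a SMILES string with RDKit; returns None for invalid input."""
    from rdkit import Chem  # imported lazily, only needed for real parsing

    return Chem.MolFromSmiles(smiles)


def filter_valid_smiles(smiles_list, parse=None):
    """Return (valid SMILES strings, number of skipped strings).

    `parse` is any callable that returns None for unparseable input;
    by default RDKit's MolFromSmiles is used.
    """
    if parse is None:
        parse = _rdkit_parse
    valid = [smiles for smiles in smiles_list if parse(smiles) is not None]
    return valid, len(smiles_list) - len(valid)
```

Called as filter_valid_smiles(list_of_smiles), the second return value gives the number of molecules that were dropped, which makes it easy to compare against the roughly 1% figure mentioned above.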

Hyper-parameter tuning and experiment management

For hyper-parameter tuning and experiment management, we use the wandb package. Note that both of these require an account on wandb.ai. Single experiments can be run without an account; in this case, please ignore the warnings.
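
As an alternative to ignoring the warnings, the wandb library can be switched off through its WANDB_MODE environment variable. The sketch below sets it before any run is started; the helper configure_wandb is hypothetical, and whether the scripts in this repository behave as expected in disabled mode has not been verified here.

```python
import os


def configure_wandb(have_account):
    """Set WANDB_MODE so wandb either syncs online or becomes a no-op.

    Must be called before wandb.init() is executed for it to take effect.
    """
    mode = "online" if have_account else "disabled"
    os.environ["WANDB_MODE"] = mode
    return mode


if __name__ == "__main__":
    print("WANDB_MODE =", configure_wandb(have_account=False))
```

The same effect can be had from the shell by exporting WANDB_MODE=disabled before launching a script.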

Contributors

jeankaddour

sin's Issues

Cannot Reproduce the results in the paper.

I generated the TCGA dataset with arg.bias=1 and used the default hyper-parameters in the repo. However, the performance of SIN is quite different from the results in the manuscript.
[image]

At the same time, other baselines achieve better performance. Am I neglecting some important settings?
Results for graphite:
[image]

Update requirements.txt file

Good repo! However, there are some errors when trying to install with requirements.txt. It seems some versions are in conflict or out of date. Also, SQLAlchem should be SQLAlchemy.