
An Empirical Framework for Domain Generalization In Clinical Settings

Paper

If you use this code in your research, please cite the following publication:

@inproceedings{zhang2021empirical,
  title={An empirical framework for domain generalization in clinical settings},
  author={Zhang, Haoran and Dullerud, Natalie and Seyyed-Kalantari, Laleh and Morris, Quaid and Joshi, Shalmali and Ghassemi, Marzyeh},
  booktitle={Proceedings of the Conference on Health, Inference, and Learning},
  pages={279--290},
  year={2021}
}

The paper is also available on arXiv: https://arxiv.org/abs/2103.11163

Acknowledgements

Our implementation is a modified version of the excellent DomainBed framework (from commit a10458a). We also make use of some code from eICU Benchmarks.

To replicate the experiments in the paper:

Step 0: Environment and Prerequisites

Run the following commands to clone this repo and create the Conda environment:

git clone https://github.com/MLforHealth/ClinicalDG.git
cd ClinicalDG/
conda env create -f environment.yml
conda activate clinicaldg

Step 1: Obtaining the Data

See DataSources.md for detailed instructions.

Step 2: Running Experiments

Experiments can be run using the same procedure as in the DomainBed framework, with a few additional adjustable data hyperparameters, which should be passed in as a JSON-formatted dictionary.

For example, to train a single model:

python -m clinicaldg.scripts.train \
       --algorithm ERM \
       --dataset eICUSubsampleUnobs \
       --es_method val \
       --hparams '{"eicu_architecture": "GRU", "eicu_subsample_g1_mean": 0.5, "eicu_subsample_g2_mean": 0.05}' \
       --output_dir /path/to/output

To sweep a range of datasets, algorithms, and hyperparameters:

python -m clinicaldg.scripts.sweep launch \
       --output_dir=/my/sweep/output/path \
       --command_launcher slurm \
       --algorithms ERMID ERM IRM VREx RVP IGA CORAL MLDG GroupDRO \
       --datasets CXR CXRBinary \
       --n_hparams 10 \
       --n_trials 5 \
       --es_method train \
       --hparams '{"cxr_augment": 1}'

A detailed list of hparams available for each dataset can be found here.

We provide the bash scripts used for our main experiments in the bash_scripts directory. You will likely need to customize them, along with the launcher, to your compute environment.

Step 3: Aggregating Results

We provide sample code for creating aggregate results for an experiment in notebooks/AggResults.ipynb.
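
As a rough illustration of what such aggregation can look like, the sketch below collects per-run records from a sweep output directory into a single pandas DataFrame. It assumes each run subdirectory contains a results.jsonl file with one JSON record per checkpoint (as in DomainBed) and that those records include a "step" field; adjust the path and keys to match your runs. The notebook above remains the reference implementation.

import json
from pathlib import Path

import pandas as pd

def load_sweep_results(sweep_dir):
    # Gather every results.jsonl written by the sweep into one table.
    records = []
    for results_file in Path(sweep_dir).glob("*/results.jsonl"):
        with open(results_file) as f:
            for line in f:
                record = json.loads(line)
                record["run_dir"] = results_file.parent.name
                records.append(record)
    return pd.DataFrame(records)

df = load_sweep_results("/my/sweep/output/path")
# Keep only the final checkpoint of each run, then summarize whatever
# metric columns the runs actually logged.
final = df.sort_values("step").groupby("run_dir").tail(1)
print(final.describe())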

License

This source code is released under the MIT license, included in this repository.


Issues

Low GPU Utilization for CXRBinary

While attempting to train a model for the CXRBinary task, I'm getting very low GPU utilisation (~5%), with batches taking a very long time to load onto the GPU.

I'm using 4 CPUs and a 32GB V100, with num_workers=4 on the dataloaders.

Strangely, the batches don't iterate evenly: I'll get four batches run through quickly, then a long pause, then another four batches, etc.

Did you encounter something like this? How did you maximise your GPU process/memory utilisation?

Training script:

hparams='{"batch_size": 64}'

python -m clinicaldg.scripts.train \
    --algorithm ERM \
    --dataset CXRBinary \
    --output_dir /scratch/rc4499/thesis/output \
    --es_method train \
    --hparams "${hparams}" \
    --max_steps 5000 \
    --checkpoint_freq 1000
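
A burst of a few fast batches followed by a long stall usually means each worker process delivers its prefetched batch and the training loop then waits on disk I/O or image decoding rather than on the GPU. The snippet below is a generic PyTorch-level sketch of DataLoader settings that often help in that situation; it uses a placeholder dataset, is not code from this repository, and the right num_workers depends on how many CPU cores your job actually gets.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset standing in for the CXR data.
dataset = TensorDataset(torch.randn(256, 3, 224, 224), torch.randint(0, 2, (256,)))

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,             # tune to the CPU cores allocated to the job
    pin_memory=True,           # faster host-to-GPU copies
    persistent_workers=True,   # keep workers alive between epochs (torch >= 1.7)
    prefetch_factor=4,         # batches prefetched per worker
)

device = "cuda" if torch.cuda.is_available() else "cpu"
for x, y in loader:
    # non_blocking copies pair with pin_memory=True
    x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
    break  # forward/backward pass would go here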

What is map.csv w.r.t. CheXpert?

Hi,

I'm trying to replicate your results and am setting up the datasets; however, I'm confused about the map.csv file referenced in the validation/preprocessing scripts for CheXpert.

In CheXpert I only have train.csv and valid.csv, not map.csv. What is this file? Is it a concatenation of train.csv and valid.csv? Alternatively, can you link to the map.csv file you used in your experiments?

Thanks!
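
For what it's worth, if map.csv really is nothing more than CheXpert's train.csv and valid.csv concatenated (which the question above only speculates and the maintainer would need to confirm), it could be reconstructed with a couple of lines; the CheXpert path below is a placeholder.

import pandas as pd

# HYPOTHETICAL: only valid if map.csv is the plain concatenation of the two
# CheXpert label files, which is an unconfirmed assumption.
chexpert_dir = "/path/to/CheXpert-v1.0"  # placeholder path
train = pd.read_csv(f"{chexpert_dir}/train.csv")
valid = pd.read_csv(f"{chexpert_dir}/valid.csv")
pd.concat([train, valid], ignore_index=True).to_csv(f"{chexpert_dir}/map.csv", index=False)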
