
Caduceus

Caduceus ☤: Bi-Directional Equivariant Long-Range DNA Sequence Modeling

[Blog]   |   [arXiv]   |   [HuggingFace 🤗]

This repository contains code for reproducing the results in the paper "Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling," Schiff et al. (2024).

Using Caduceus with 🤗

We have uploaded pre-trained Caduceus models to the Hugging Face Hub. The available models are listed on the Caduceus collection page on the hub.

You can either use the pre-trained model directly within your trainer scripts or modify the config that initializes the model.

To use the pre-trained model for masked language modeling, use the following snippet:

from transformers import AutoModelForMaskedLM, AutoTokenizer

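# See the `Caduceus` collection page on the hub for the list of available models.
model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
# The hub checkpoints ship custom model code, so `trust_remote_code=True` is passed here.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)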

Alternatively, you can instantiate a model from scratch to train on your own data as follows:

from transformers import AutoConfig, AutoModelForMaskedLM

# Add any config overrides here, see the `config.json` file on the hub for details.
config_overrides = {}
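# See the `Caduceus` collection page on the hub for the list of available models.
# The hub checkpoints ship custom model code, so `trust_remote_code=True` is passed here.
config = AutoConfig.from_pretrained(
    "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16",
    trust_remote_code=True,
    **config_overrides,
)
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)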

Getting started in this repository

To get started, create a conda environment containing the required dependencies.

conda env create -f caduceus_env.yml

Activate the environment.

conda activate caduceus_env

Create the following directories to store saved models and slurm logs:

mkdir outputs
mkdir watch_folder

Reproducing Experiments

Below, we describe the steps required for reproducing the experiments in the paper. Throughout, the main entry point for running experiments is the train.py script. We also provide sample slurm scripts for launching pre-training and downstream fine-tuning experiments in the slurm_scripts/ directory.

Pretraining on Human Reference Genome

(Data downloading instructions are copied from the HyenaDNA repo.)

First, download the Human Reference Genome data. It consists of two files: one with all the sequences (the .fasta file) and one with the intervals we use (the .bed file).

The file structure should look like this:

data
|-- hg38/
    |-- hg38.ml.fa
    |-- human-sequences.bed

Download the FASTA (.fa) file of the entire human genome into ./data/hg38. The genome has roughly 24 chromosomes, merged into one file, and each chromosome is essentially one continuous sequence. Then download the .bed file with the sequence intervals. Each row contains the chromosome name, start, end, and split, which let you retrieve the corresponding sequence from the FASTA file.

mkdir -p data/hg38/
curl https://storage.googleapis.com/basenji_barnyard2/hg38.ml.fa.gz > data/hg38/hg38.ml.fa.gz
gunzip data/hg38/hg38.ml.fa.gz  # unzip the fasta file
curl https://storage.googleapis.com/basenji_barnyard2/sequences_human.bed > data/hg38/human-sequences.bed
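
To illustrate how the .bed intervals index into the FASTA sequences, here is a minimal sketch, not the repository's data loader. It assumes tab-separated rows in the order described above (chromosome name, start, end, split) and the usual BED convention of 0-based, half-open coordinates. The names `Interval`, `parse_bed_line`, and `fetch_sequence`, the toy data, and the split labels are hypothetical.

```python
from typing import Dict, NamedTuple


class Interval(NamedTuple):
    chrom: str
    start: int  # 0-based, inclusive (BED convention, assumed here)
    end: int    # 0-based, exclusive
    split: str  # split label, e.g. "train" or "valid" (assumed values)


def parse_bed_line(line: str) -> Interval:
    """Parse one tab-separated row: chromosome name, start, end, split."""
    chrom, start, end, split = line.rstrip("\n").split("\t")[:4]
    return Interval(chrom, int(start), int(end), split)


def fetch_sequence(genome: Dict[str, str], interval: Interval) -> str:
    """Slice the interval out of an in-memory {chromosome: sequence} mapping."""
    return genome[interval.chrom][interval.start:interval.end]


# Toy stand-ins for hg38.ml.fa and human-sequences.bed.
toy_genome = {"chr1": "ACGTACGTAC", "chr2": "TTTTGGGGCC"}
toy_bed = "chr1\t2\t6\ttrain\nchr2\t4\t10\tvalid\n"

intervals = [parse_bed_line(row) for row in toy_bed.splitlines()]
sequences = [fetch_sequence(toy_genome, iv) for iv in intervals]
```

For the real hg38 file, holding whole chromosomes in a dictionary is memory-hungry, so an indexed FASTA reader would be a more practical choice than this in-memory mapping.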

Launch a pre-training run from the command line:

python -m train \
  experiment=hg38/hg38 \
  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
  dataset.max_length=1024 \
  dataset.batch_size=1024 \
  dataset.mlm=true \
  dataset.mlm_probability=0.15 \
  dataset.rc_aug=false \
  model=caduceus \
  model.config.d_model=128 \
  model.config.n_layer=4 \
  model.config.bidirectional=true \
  model.config.bidirectional_strategy=add \
  model.config.bidirectional_weight_tie=true \
  model.config.rcps=true \
  optimizer.lr="8e-3" \
  train.global_batch_size=8 \
  trainer.max_steps=10000 \
  +trainer.val_check_interval=10000 \
  wandb=null

Alternatively, on a cluster that has Slurm installed, adapt the scripts below:

slurm_scripts
|-- run_pretrain_caduceus.sh
|-- run_pretrain_hyena.sh
|-- run_pretrain_mamba.sh

and run the training as a batch job:

cd slurm_scripts
sbatch run_pretrain_caduceus.sh

GenomicBenchmarks

The GenomicBenchmarks suite presented in Grešová et al. (2023) comprises 8 classification tasks.

We can launch a downstream fine-tuning run on one of the tasks using the sample command below:

python -m train \
    experiment=hg38/genomic_benchmark \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="dummy_mouse_enhancers_ensembl" \
    dataset.train_val_split_seed=1 \
    dataset.batch_size=128 \
    dataset.rc_aug=false \
    +dataset.conjoin_train=false \
    +dataset.conjoin_test=false \
    loader.num_workers=2 \
    model=caduceus \
    model._name_=dna_embedding_caduceus \
    +model.config_path="<path to model_config.json>" \
    +model.conjoin_test=false \
    +decoder.conjoin_train=true \
    +decoder.conjoin_test=false \
    optimizer.lr="1e-3" \
    trainer.max_epochs=10 \
    train.pretrained_model_path="<path to .ckpt file>" \
    wandb=null

This sample run will fine-tune a pre-trained Caduceus-PS model on the dummy_mouse_enhancers_ensembl task. Note some of the additional arguments present here, relative to the pre-training command from above:

  • model.config_path contains the path to the model config that was saved during pre-training. This file is saved to the run directory of the pre-training experiment.
  • train.pretrained_model_path contains the path to the pre-trained model checkpoint.
  • dataset.conjoin_train determines whether, during downstream fine-tuning training, the dataset returns a single sequence (dataset.conjoin_train=false) or the concatenation of a sequence and its reverse complement along dim=-1.
  • dataset.conjoin_test is the same as above, but for inference (e.g., validation / test).
  • decoder.conjoin_train determines whether, during downstream fine-tuning training, the prediction head (mean pooling followed by a linear projection in the case of the GenomicBenchmarks) expects an input tensor of shape (batch_size, seq_len, d_model) or (batch_size, seq_len, d_model, 2). When set to true, the decoder is run on input[..., 0] and input[..., 1] and the results are averaged to produce the final prediction.
  • decoder.conjoin_test is the same as above, but for inference (e.g., validation / test).
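
The conjoining options above can be sketched in plain Python as follows. This is an illustrative sketch under stated assumptions, not the repository's implementation: the batch dimension is omitted, nested lists stand in for tensors, and the names `reverse_complement`, `conjoin`, `mean_pool_linear`, and `conjoined_decoder` are hypothetical.

```python
from typing import List, Sequence

COMPLEMENT = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"}


def reverse_complement(seq: str) -> str:
    """Reverse the sequence and swap each base for its complement."""
    return "".join(COMPLEMENT[base] for base in reversed(seq))


def conjoin(seq: str) -> List[List[str]]:
    """Stack a sequence and its reverse complement along a new last dim of size 2."""
    rc = reverse_complement(seq)
    return [[fwd, rev] for fwd, rev in zip(seq, rc)]


def mean_pool_linear(
    hidden: Sequence[Sequence[float]], weight: Sequence[Sequence[float]]
) -> List[float]:
    """Mean-pool (seq_len, d_model) over seq_len, then project with (num_classes, d_model)."""
    seq_len, d_model = len(hidden), len(hidden[0])
    pooled = [sum(step[j] for step in hidden) / seq_len for j in range(d_model)]
    return [sum(w * x for w, x in zip(row, pooled)) for row in weight]


def conjoined_decoder(hidden, weight) -> List[float]:
    """Run the head on hidden[..., 0] and hidden[..., 1], then average the two outputs."""
    strand0 = [[feat[0] for feat in step] for step in hidden]
    strand1 = [[feat[1] for feat in step] for step in hidden]
    logits0 = mean_pool_linear(strand0, weight)
    logits1 = mean_pool_linear(strand1, weight)
    return [(a + b) / 2 for a, b in zip(logits0, logits1)]


# Toy input of shape (seq_len=2, d_model=2, 2) and an identity projection.
toy_hidden = [[[1.0, 3.0], [2.0, 4.0]], [[3.0, 5.0], [4.0, 6.0]]]
toy_weight = [[1.0, 0.0], [0.0, 1.0]]
logits = conjoined_decoder(toy_hidden, toy_weight)
```

Averaging the head's outputs over the two strands means that swapping the two entries of the last dimension leaves the prediction unchanged.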

Note that this benchmark contains only a training and a test split for each task. Therefore, for a more principled evaluation, we randomly split the training data into training and validation sets (90/10), with the split controlled by the dataset.train_val_split_seed argument. We perform early stopping on the validation metric (accuracy) and repeat this for 5 random seeds.
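
A seeded 90/10 split of this kind could look like the sketch below. It is a generic illustration and makes no claim about how the repository implements dataset.train_val_split_seed. The function name `train_val_split`, the example count of 100, and the seed values 1 to 5 are assumptions.

```python
import random
from typing import List, Tuple


def train_val_split(
    n_examples: int, seed: int, val_fraction: float = 0.1
) -> Tuple[List[int], List[int]]:
    """Shuffle example indices with a fixed seed and hold out a validation fraction."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)  # local RNG, so global random state is untouched
    n_val = int(round(n_examples * val_fraction))
    return indices[n_val:], indices[:n_val]


# One split per seed, mirroring the "repeat for 5 random seeds" protocol.
splits = {seed: train_val_split(100, seed) for seed in range(1, 6)}
train_idx, val_idx = splits[1]
```

Using a dedicated `random.Random(seed)` instance makes each split reproducible from its seed alone, independent of any other randomness in the run.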

As with pre-training, we can also launch the fine-tuning run as a batch job using the provided run_genomic_benchmark.sh script. We also provide a helper shell script wrapper_run_genomics.sh that can be used to launch multiple fine-tuning runs in parallel.

Finally, the run_genomics_benchmark_cnn.sh script can be used to train the CNN baseline for this experiment from scratch on the downstream tasks.

Nucleotide Transformer datasets

The Nucleotide Transformer suite of tasks was proposed in Dalla-Torre et al. (2023). The data is available on HuggingFace: InstaDeepAI/nucleotide_transformer_downstream_tasks.

We can launch a downstream fine-tuning run on one of the tasks using the sample command below:

python -m train \
    experiment=hg38/nucleotide_transformer \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="${task}" \
    dataset.train_val_split_seed=${seed} \
    dataset.batch_size=${batch_size} \
    dataset.rc_aug="${rc_aug}" \
    +dataset.conjoin_test="${CONJOIN_TEST}" \
    loader.num_workers=2 \
    model._name_=dna_embedding_caduceus \
    +model.config_path="<path to model_config.json>" \
    +model.conjoin_test=false \
    +decoder.conjoin_train=true \
    +decoder.conjoin_test=false \
    optimizer.lr="1e-3" \
    train.pretrained_model_path="<path to .ckpt file>" \
    trainer.max_epochs=20 \
    wandb=null

We can also launch these runs as batch jobs (see run_nucleotide_transformer.sh and wrapper_run_nucleotide_transformer.sh for details).

eQTL SNP Variant Effect Prediction

TODO: This section is a work in progress and is pending the open source release of the dataset used in the corresponding experiment. We plan to update the details here soon!

Citation

If you find our work useful, please cite our paper using the following:

@article{schiff2024caduceus,
  title={Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling},
  author={Schiff, Yair and Kao, Chia-Hsiang and Gokaslan, Aaron and Dao, Tri and Gu, Albert and Kuleshov, Volodymyr},
  journal={arXiv preprint arXiv:2403.03234},
  year={2024}
}

Acknowledgements

This repository is adapted from the HyenaDNA repo and leverages much of the training, data loading, and logging infrastructure defined there. HyenaDNA was originally derived from the S4 and Safari repositories.

We would like to thank Evan Trop and the InstaDeep team for useful discussions about the Nucleotide Transformer leaderboard.

Finally, we would like to thank MosaicML for providing compute resources for some of the pre-training experiments.
