
instadeepai / nucleotide-transformer


🧬 Nucleotide Transformer: Building and Evaluating Robust Foundation Models for Human Genomics

Home Page: https://www.biorxiv.org/content/10.1101/2023.01.11.523679v2

License: Other

Python 100.00%
genomics deep-learning language-model nucleotide transformer dna

nucleotide-transformer's Introduction

Nucleotide Transformers and SegmentNT


Welcome to this InstaDeep GitHub repository, which features:

  1. A collection of transformer-based genomic language models from both of our research works, The Nucleotide Transformer and Agro Nucleotide Transformer.
  2. A collection of segmentation models that use the Nucleotide Transformers as a backbone to segment the genomic elements of a DNA sequence at single-nucleotide resolution: the SegmentNT models.

We are thrilled to open-source these works and provide the community with access to the code and pre-trained weights for these nine genomic language models and two segmentation models. Models from The Nucleotide Transformer project were developed in collaboration with Nvidia and TUM and trained on DGX A100 nodes on Cambridge-1. The model from the Agro Nucleotide Transformer project was developed in collaboration with Google and trained on TPU-v4 accelerators.

Overall, our works provide novel insights into the pretraining and application of foundation language models for genomics, as well as into training models that use them as a backbone encoder, opening ample opportunities for applications in the field.

In this repository, you will find the following:

  • Inference code for our models
  • Pre-trained weights for all 9 NT models and 2 SegmentNT models
  • Instructions for using the code and pre-trained models

The Nucleotide Transformer Models

Unlike other approaches, our models do not only integrate information from single reference genomes but also leverage DNA sequences from over 3,200 diverse human genomes, as well as 850 genomes from a wide range of species, including model and non-model organisms. Through robust and extensive evaluation, we show that these large models provide extremely accurate molecular phenotype predictions compared with existing methods.

Performance on downstream tasks

Fig. 1: The Nucleotide Transformer model accurately predicts diverse genomics tasks after fine-tuning. We show the performance results across downstream tasks for fine-tuned transformer models. Error bars represent 2 SDs derived from 10-fold cross-validation.

Agro Nucleotide Transformer Model

In this work, we present AgroNT, a novel foundational large language model trained on reference genomes from 48 plant species, with a predominant focus on crop species. We assessed the performance of AgroNT across several prediction tasks spanning regulatory features, RNA processing, and gene expression, and show that AgroNT can obtain state-of-the-art performance.

AgroNT Performance on Gene Expression

Fig. 2: AgroNT provides gene expression prediction across different plant species. Gene expression predictions on holdout genes across all tissues are correlated with observed gene expression levels. The coefficient of determination (R²) from a linear model and the associated P-values between predicted and observed values are shown.

Get started 🚀

To use the code and pre-trained models, simply:

  1. Clone the repository to your local machine.
  2. Install the package by running pip install . from the root of the cloned repository.

You can then download any of our nine models and run inference in only a few lines of code:

import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model

# Get pretrained model
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name="500M_human_ref",
    embeddings_layers_to_save=(20,),
    max_positions=32,
    # If the progress bar gets stuck at the start of the model weights download,
    # you can set verbose=False to download without the progress bar.
    verbose=True
)
forward_fn = hk.transform(forward_fn)

# Get data and tokenize it
sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"]
tokenized = tokenizer.batch_tokenize(sequences)
tokens_ids = [b[1] for b in tokenized]
tokens_str = [b[0] for b in tokenized]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

# Initialize random key
random_key = jax.random.PRNGKey(0)

# Infer
outs = forward_fn.apply(parameters, random_key, tokens)

# Get embeddings at layer 20
print(outs["embeddings_20"].shape)

Supported model names are:

  • 500M_human_ref
  • 500M_1000G
  • 2B5_1000G
  • 2B5_multi_species
  • 50M_multi_species_v2
  • 100M_multi_species_v2
  • 250M_multi_species_v2
  • 500M_multi_species_v2
  • 1B_agro_nt

You can also run our models and find more example code in Google Colab.

Thanks to JAX, the code runs on both GPU and TPU!

Nucleotide Transformers v2 models

Our second-generation Nucleotide Transformer v2 models include a series of architectural changes that proved more efficient: instead of learned positional embeddings, they use rotary embeddings applied at each attention layer, and they use gated linear units with swish activations and no bias. These improved models also accept sequences of up to 2,048 tokens, giving a longer context window of about 12 kbp. Inspired by the Chinchilla scaling laws, we also trained the NT-v2 models on our multi-species dataset for longer (300B tokens for the 50M and 100M models; 1T tokens for the 250M and 500M models) than the v1 models (300B tokens for all four models).
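The two v2 architectural changes can be illustrated with a minimal NumPy sketch: rotary position embeddings, which rotate each pair of query/key features by a position-dependent angle so that the dot product between a rotated query and key depends only on their relative offset, and a bias-free gated linear unit with swish activation. This is not the NT-v2 code (the repository is written in JAX/Haiku); the interleaved-pair RoPE layout, the base of 10000, the function names and the layer sizes are all assumptions chosen for illustration.

```python
import numpy as np


def apply_rotary(x: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Rotary position embedding sketch for x of shape (seq_len, dim), dim even."""
    seq_len, dim = x.shape
    assert dim % 2 == 0, "feature dimension must be even"
    # One frequency per (even, odd) feature pair, geometrically spaced.
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))      # (dim/2,)
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]     # (seq_len, dim/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    # 2-D rotation of each pair by its position-dependent angle.
    rotated = np.stack([x_even * cos - x_odd * sin,
                        x_even * sin + x_odd * cos], axis=-1)     # (seq_len, dim/2, 2)
    # Re-interleave the pairs back into (seq_len, dim).
    return rotated.reshape(seq_len, dim)


def swish(x: np.ndarray) -> np.ndarray:
    # x * sigmoid(x), with sigmoid written via tanh for numerical stability.
    return x * 0.5 * (1.0 + np.tanh(0.5 * x))


def swiglu_ffn(x, w_gate, w_up, w_down):
    """Gated linear unit with swish activation and no bias terms."""
    return (swish(x @ w_gate) * (x @ w_up)) @ w_down


# Example: rotate 8 positions of 16-dim query vectors, then run a SwiGLU layer.
rng = np.random.default_rng(0)
q = rng.normal(size=(8, 16))
q_rot = apply_rotary(q)
w_gate, w_up = rng.normal(size=(16, 32)), rng.normal(size=(16, 32))
w_down = rng.normal(size=(32, 16))
h = swiglu_ffn(q, w_gate, w_up, w_down)
print(q_rot.shape, h.shape)  # (8, 16) (8, 16)
```

jax.numpy provides the same array operations (and jax.nn.swish), so the sketch translates directly to JAX.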

Embeddings retrieval

The transformer layers are 1-indexed, which means that calling get_pretrained_model with the arguments model_name="500M_human_ref" and embeddings_layers_to_save=(1, 20,) extracts embeddings after the 1st and 20th transformer layers. For transformers using the RoBERTa LM head, it is common practice to extract the final embeddings after the first layer norm of the LM head rather than after the last transformer block. Therefore, if get_pretrained_model is called with embeddings_layers_to_save=(24,), the embeddings are not extracted after the final transformer layer but after the first layer norm of the LM head.
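The layer-numbering rule above can be written down as a small hypothetical helper (describe_saved_embeddings is not part of the package). It assumes, as the paragraph implies, that 500M_human_ref has 24 transformer layers, and that saved embeddings appear under keys of the form embeddings_<layer>, as in the outs["embeddings_20"] example earlier.

```python
def describe_saved_embeddings(layers_to_save, num_layers):
    """Map 1-indexed layer numbers to where the embedding is taken (hypothetical helper)."""
    described = {}
    for layer in layers_to_save:
        # Layers are 1-indexed: 0 is not a valid layer number.
        if not 1 <= layer <= num_layers:
            raise ValueError(f"layer {layer} outside 1..{num_layers} (layers are 1-indexed)")
        if layer == num_layers:
            # Last layer: taken after the first layer norm of the RoBERTa-style LM head.
            where = "after the first layer norm of the LM head"
        else:
            where = f"after transformer block {layer}"
        # Keys follow the embeddings_<layer> pattern, e.g. outs["embeddings_20"].
        described[f"embeddings_{layer}"] = where
    return described


print(describe_saved_embeddings((1, 20, 24), num_layers=24))
```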


The SegmentNT Models

SegmentNT models use a Nucleotide Transformer (NT) backbone whose language-model head has been replaced by a 1-dimensional U-Net segmentation head, which predicts the location of several types of genomic elements in a sequence at single-nucleotide resolution. We present two model variants that segment 14 classes of human genomic elements in input sequences of up to 30 kb. These include gene elements (protein-coding genes, lncRNAs, 5′UTR, 3′UTR, exon, intron, splice acceptor and donor sites) and regulatory elements (polyA signal, tissue-invariant and tissue-specific promoters and enhancers, and CTCF-bound sites). Benefiting from the pre-trained weights of NT, SegmentNT outperforms the state-of-the-art U-Net segmentation architecture and demonstrates zero-shot generalization to sequences of up to 50 kbp.

Performance on downstream tasks

Fig. 1: SegmentNT localizes genomic elements at nucleotide resolution.

Get started 🚀

To use the code and pre-trained models, simply:

  1. Clone the repository to your local machine.
  2. Install the package by running pip install . from the root of the cloned repository.

You can then download any of our models and infer on a sequence in only a few lines of code:

โš ๏ธ The SegmentNT models have been trained on a sequences of 30,000 nucleotides, or 5001 tokens (accounting for the CLS token). However, SegmentNT has been shown to generalize up to sequences of 50,000 bp. For training on 30,000 bps, which is a length superior than the maximum length of 2048 6-mers tokens that the nucleotide transformer can handle, Yarn rescaling is employed. By default, the rescaling factor is set to the one used during the training. In case you need to infer on sequences between 30kbp and 50kbp, make sure to pass the rescaling_factor argument in the get_pretrained_segment_nt_model function with the value rescaling_factor = max_num_nucleotides / max_num_tokens_nt where num_dna_tokens_inference is the number of tokens at inference (i.e 6669 for a sequence of 40008 base pairs) and max_num_tokens_nt is the max number of tokens on which the backbone nucleotide-transformer was trained on, i.e 2048.

๐Ÿ” The notebook examples/inference_segment_nt.ipynb showcases how to infer on a 50kb sequence and plot the probabilities to reproduce the Fig.3 of the paper.

🚧 The SegmentNT models do not handle any "N" in the input sequence, because every nucleotide must be tokenized as part of a 6-mer, which is impossible when the sequence contains one or more "N" bases.
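Before calling a SegmentNT model, it can help to check a sequence against the constraints described in the notes above and in the comments of the example code below: no "N", a length that is a multiple of 6, and a number of DNA tokens (excluding CLS) divisible by 4. The following is a hypothetical pre-flight check, not a function of the package; it also assumes that only uppercase A, C, G and T are acceptable.

```python
VALID_BASES = frozenset("ACGT")


def check_segment_nt_sequence(sequence: str, k: int = 6, divisor: int = 4) -> int:
    """Return the number of 6-mer DNA tokens (excluding CLS) if the sequence is usable."""
    bad = sorted(set(sequence) - VALID_BASES)
    if bad:
        # Covers "N" as well as any other unexpected character.
        raise ValueError(f"unsupported characters {bad}: SegmentNT cannot handle 'N'")
    if len(sequence) % k:
        raise ValueError(f"length {len(sequence)} is not a multiple of {k}")
    num_dna_tokens = len(sequence) // k
    if num_dna_tokens % divisor:
        # The U-Net head downsamples, so the token count must be divisible by 4.
        raise ValueError(f"{num_dna_tokens} DNA tokens is not divisible by {divisor}")
    return num_dna_tokens


print(check_segment_nt_sequence("ACGTAC" * 8))  # 8 tokens for 48 bp
```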

import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model

# Initialize CPU as default JAX device. This makes the code robust to memory leakage on
# the devices.
jax.config.update("jax_platform_name", "cpu")

backend = "gpu"
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

# The number of DNA tokens (excluding the prepended CLS token) must be divisible by
# 2 to the power of the number of downsampling blocks, i.e. 4. Despite its name,
# max_num_nucleotides counts 6-mer DNA tokens here: 8 tokens = 48 nucleotides.
max_num_nucleotides = 8

assert max_num_nucleotides % 4 == 0, (
    "The number of DNA tokens (excluding the prepended CLS token) must be divisible by "
    "2 to the power of the number of downsampling blocks, i.e. 4.")

parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    embeddings_layers_to_save=(29,),
    attention_maps_to_save=((1, 4), (7, 10)),
    max_positions=max_num_nucleotides + 1,
    # If the progress bar gets stuck at the start of the model weights download,
    # you can set verbose=False to download without the progress bar.
    verbose=True
)
forward_fn = hk.transform(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))

random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)

# Get data and tokenize it
sequences = ["ATTCCGATTCCGATTCCAACGGATTATTCCGATTAACCGATTCCAATT", "ATTTCTCTCTCTCTCTGAGATCGATGATTTCTCTCTCATCGAACTATG"]
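tokenized = tokenizer.batch_tokenize(sequences)
tokens_ids = [b[1] for b in tokenized]
tokens_str = [b[0] for b in tokenized]
# jax.pmap maps over the leading axis, whose size must equal the number of devices,
# so the batch is replicated once per device: (num_devices, batch_size, num_tokens).
tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)] * num_devices, axis=0)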

# Infer on the sequence
outs = apply_fn(parameters, keys, tokens)
# Obtain the logits over the genomic features
logits = outs["logits"]
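# Transform them into probabilities: softmax over the last axis, then keep its last
# entry, i.e. the probability that each nucleotide belongs to each genomic feature
probabilities = jnp.asarray(jax.nn.softmax(logits, axis=-1))[..., -1]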
print(f"Probabilities shape: {probabilities.shape}")

print(f"Features inferred: {config.features}")

# Get probabilities associated with intron
idx_intron = config.features.index("intron")
probabilities_intron = probabilities[..., idx_intron]
print(f"Intron probabilities shape: {probabilities_intron.shape}")
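The per-nucleotide probabilities can be turned into predicted regions by thresholding and collecting contiguous runs. This is a hypothetical post-processing sketch (the 0.5 threshold and the probabilities_to_intervals name are assumptions, not part of SegmentNT); it expects a 1-D array of probabilities for a single sequence and a single feature, such as one row of probabilities_intron.

```python
import numpy as np


def probabilities_to_intervals(probs, threshold=0.5):
    """Return half-open [start, end) nucleotide intervals where probs > threshold."""
    above = np.asarray(probs) > threshold
    # Pad with False on both sides so every run has a rising and a falling edge.
    padded = np.concatenate([[False], above, [False]]).astype(np.int8)
    # Non-zero differences mark run starts (+1) and run ends (-1), alternating.
    edges = np.flatnonzero(np.diff(padded))
    return [(int(start), int(end)) for start, end in zip(edges[::2], edges[1::2])]


probs = np.array([0.1, 0.7, 0.9, 0.2, 0.6, 0.8, 0.95, 0.3])
print(probabilities_to_intervals(probs))  # [(1, 3), (4, 7)]
```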

Supported model names are:

  • segment_nt
  • segment_nt_multi_species

Thanks to JAX, the code runs on both GPU and TPU!


Tokenization 🔤

The v1 models are trained on sequences of up to 1,000 tokens (the v2 models accept up to 2,048 tokens), including the <CLS> token that is automatically prepended to the beginning of the sequence. The tokenizer works from left to right, grouping the letters "A", "C", "G" and "T" into 6-mers. The "N" letter is never grouped inside k-mers: whenever the tokenizer encounters an "N", or when the nucleotides remaining before an "N" or the end of the sequence do not form a complete 6-mer, it tokenizes those nucleotides individually. Examples are given below:

dna_sequence_1 = "ACGTGTACGTGCACGGACGACTAGTCAGCA"
tokenized_dna_sequence_1 = [<CLS>,<ACGTGT>,<ACGTGC>,<ACGGAC>,<GACTAG>,<TCAGCA>]

dna_sequence_2 = "ACGTGTACNTGCACGGANCGACTAGTCTGA"
tokenized_dna_sequence_2 = [<CLS>,<ACGTGT>,<A>,<C>,<N>,<TGCACG>,<G>,<A>,<N>,<CGACTA>,<GTCTGA>]
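The grouping rule can be mimicked in a few lines of plain Python. This sketch (tokenize_6mers is a hypothetical name, not the package's tokenizer) splits the sequence on "N", groups each N-free stretch into 6-mers from the left and emits any leftover nucleotides one by one; it reproduces the two examples above, but edge cases may differ from the real tokenizer.

```python
def tokenize_6mers(sequence: str, k: int = 6) -> list:
    """Sketch of the k-mer grouping rule: 6-mers where possible, single tokens otherwise."""
    tokens = ["<CLS>"]
    for i, chunk in enumerate(sequence.split("N")):
        if i > 0:
            tokens.append("N")  # each N is always its own token
        n_grouped = (len(chunk) // k) * k
        # Full 6-mers, read from left to right.
        tokens += [chunk[j:j + k] for j in range(0, n_grouped, k)]
        # Leftover nucleotides that do not fill a 6-mer, one token each.
        tokens += list(chunk[n_grouped:])
    return tokens


print(tokenize_6mers("ACGTGTACNTGCACGGANCGACTAGTCTGA"))
```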

The v1 and v2 transformers can therefore take sequences of up to 5,994 and 12,282 nucleotides respectively, provided the sequences contain no "N".
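These limits follow from simple arithmetic: one position is taken by <CLS>, and every other token carries 6 nucleotides. A small sketch (max_nucleotides is a hypothetical helper), which also recovers the 30,000 bp training length of SegmentNT from its 5,001 tokens:

```python
def max_nucleotides(max_tokens: int, k: int = 6) -> int:
    """Longest N-free sequence that fits: one position goes to <CLS>, the rest hold k-mers."""
    return (max_tokens - 1) * k


print(max_nucleotides(1000))  # v1: 5994
print(max_nucleotides(2048))  # v2: 12282
print(max_nucleotides(5001))  # SegmentNT training length: 30000
```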


HuggingFace 🤗

The collection of models presented in this repository is available on InstaDeep's Hugging Face spaces: the Nucleotide Transformers space and the Agro Nucleotide Transformer space!

  • Nucleotide Transformer: two example notebooks showing how to fine-tune any of the models, with regular fine-tuning and with LoRA, on any of the Nucleotide Transformer tasks are also available in the HuggingFace example notebooks.
  • SegmentNT: An inference notebook shows how to use the torch SegmentNT model to infer on a given 50kb sequence.

Acknowledgments 🙏

We thank Maša Roller, as well as members of the Rostlab, particularly Tobias Olenyi, Ivan Koludarov, and Burkhard Rost, for constructive discussions that helped identify interesting research directions. Furthermore, we extend our gratitude to all those who deposit experimental data in public databases, to those who maintain these databases, and to those who make analytical and predictive methods freely available. We also thank the JAX development team.

Citing our works 📚

If you find this repository useful in your work, please add a relevant citation to any of our associated papers:

The Nucleotide Transformer paper:

@article{dalla2023nucleotide,
  title={The Nucleotide Transformer: Building and Evaluating Robust Foundation Models for Human Genomics},
  author={Dalla-Torre, Hugo and Gonzalez, Liam and Mendoza Revilla, Javier and Lopez Carranza, Nicolas and Henryk Grywaczewski, Adam and Oteri, Francesco and Dallago, Christian and Trop, Evan and Sirelkhatim, Hassan and Richard, Guillaume and others},
  journal={bioRxiv},
  pages={2023--01},
  year={2023},
  publisher={Cold Spring Harbor Laboratory}
}

Agro Nucleotide Transformer paper:

@article{mendoza2023foundational,
  title={A Foundational Large Language Model for Edible Plant Genomes},
  author={Mendoza-Revilla, Javier and Trop, Evan and Gonzalez, Liam and Roller, Masa and Dalla-Torre, Hugo and de Almeida, Bernardo P and Richard, Guillaume and Caton, Jonathan and Lopez Carranza, Nicolas and Skwark, Marcin and others},
  journal={bioRxiv},
  pages={2023--10},
  year={2023},
  publisher={Cold Spring Harbor Laboratory}
}

SegmentNT paper:

@article{de2024segmentnt,
  title={SegmentNT: annotating the genome at single-nucleotide resolution with DNA foundation models},
  author={de Almeida, Bernardo P and Dalla-Torre, Hugo and Richard, Guillaume and Blum, Christopher and Hexemer, Lorenz and Gelard, Maxence and Pandey, Priyanka and Laurent, Stefan and Laterre, Alexandre and Lang, Maren and others},
  journal={bioRxiv},
  pages={2024--03},
  year={2024},
  publisher={Cold Spring Harbor Laboratory}
}

If you have any questions or feedback on the code and models, please feel free to reach out to us.

Thank you for your interest in our work!

nucleotide-transformer's People

Contributors

biogeek, dallatt, e-trop, ranzentom


nucleotide-transformer's Issues

Is the test split available?

Hi! Really cool work! I'm really interested in the downstream evaluations you guys did in the paper too. Is there information about how the test split was generated (i.e. to be able to reproduce it)? Thank you!

error while executing the Google Colab inference_nt notebook

I am trying to tailor the script to my gene of interest. I repeatedly get this error:
ValueError: Input length must be divisible by the 2 to the power of number of poolign layers.

I read your article as well, but I could not find the sequence length accepted by your model anywhere.
Could you kindly help?

Performance on RNA sequences

Thank you for open-sourcing these amazing models. Since the model is named Nucleotide Transformer, I wonder whether you have tested it on RNA datasets. I am well aware that it was not originally trained on RNA sequences, and there is no U in the vocabulary. However, if we simply change 'U' to 'T', how would it perform?
Bests,
Ai

Provide more examples and real world use cases

Great work, and love that everything is easily runnable in Colab. However, I am not clear on what are the most direct real world applications of this for bioinformatics or other omics workflows. Can you please provide some examples (worked notebooks would be nice, but simple text blurbs is fine too) so that I can better understand how this product could be used in the real world?

v2 checkpoints for TensorFlow

Hi,

thank you for your incredible work!

How much work would it be to generate model checkpoints for your updated v2 models that can be loaded and finetuned in TensorFlow and do you, by chance, have plans to support TensorFlow in the future?

Best,
Felix

Fix default `max_positions` in `get_pretrained_model`

Description

In the current version of the code, get_pretrained_model has max_positions=1024 by default but the models have been trained using max_positions=1000.

We should change the default value, as it triggers a ValueError when calling the model.

Prediction using fine-tuned model with LoRA

Hi,

We fine-tuned the NT model with LoRA as described in the fine-tuning examples. The evaluation of the fine-tuned model on our downstream task's test.csv dataset seemed reasonable. However, we get the warning messages below when attempting to perform prediction, and the prediction results are very poor. Is this related to how we save the model, or to the prediction script?

def tune(model_name, run_name, data_args_data_path, tuned_model_path, MODEL_MAX_LEN, NUM_EPOCHS=10, LR=6.1e-5, BATCH_SIZE=8, SEED=42, initate_process=False, is_fp16=True):
    
    num_labels = 3
    
    # Load the model
    model = AutoModelForSequenceClassification.from_pretrained(model_name,
                                                               num_labels=num_labels,
                                                               trust_remote_code=True)
    
    peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=1, lora_alpha= 32, lora_dropout=0.1, target_modules= ["query", "value"])
    lora_classifier = get_peft_model(model, peft_config) # transform the classifier into a peft model
    lora_classifier.print_trainable_parameters()
    
    # Load Tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name,
        model_max_length=MODEL_MAX_LEN,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True)
    
    # tokenizer.eos_token = tokenizer.pad_token
    
    data_kmer = -1

    train_dataset = SupervisedDataset(tokenizer=tokenizer, 
                                        data_path=os.path.join(data_args_data_path, "train.csv"), 
                                        kmer=data_kmer)
    val_dataset = SupervisedDataset(tokenizer=tokenizer, 
                                        data_path=os.path.join(data_args_data_path, "dev.csv"), 
                                        kmer=data_kmer)
    test_dataset = SupervisedDataset(tokenizer=tokenizer, 
                                        data_path=os.path.join(data_args_data_path, "test.csv"), 
                                        kmer=data_kmer)
    
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    tuning_args = TrainingArguments(
        output_dir=tuned_model_path,
        model_max_length=MODEL_MAX_LEN,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps= 1,
        per_device_eval_batch_size= 64,
        num_train_epochs= NUM_EPOCHS,
        learning_rate=LR,
        evaluation_strategy="steps",
        save_strategy="steps",
        logging_steps= 100000,
        save_steps=200,
        warmup_steps=50,
        load_best_model_at_end=True,  # Keep the best model according to the evaluation
        dataloader_drop_last=True,
        overwrite_output_dir=True,
        fp16=is_fp16,
        find_unused_parameters=False,
        remove_unused_columns=False,
        seed=SEED
    )
    
    trainer = transformers.Trainer(model = model,
                            tokenizer=tokenizer,
                            args=tuning_args,
                            compute_metrics=compute_metrics,
                            train_dataset=train_dataset,
                            eval_dataset=val_dataset,
                            data_collator=data_collator)
    
    train_results = trainer.train()
    
    os.makedirs(tuned_model_path, exist_ok=True)
    trainer.save_state()
    trainer.save_model(tuned_model_path)
    
    if tuning_args.eval_and_save_results:
        results_path = os.path.join(tuning_args.output_dir, "results", run_name)
        results = trainer.evaluate(eval_dataset=test_dataset)
        os.makedirs(results_path, exist_ok=True)
        with open(os.path.join(results_path, "eval_results.json"), "w") as f:
            json.dump(results, f)


def predict(pre_trained_model_name, tuned_model_path, data_path , data_kmer, MODEL_MAX_LEN, output_dir, BATCH_SIZE=64, SEED=42, initate_process=False,  is_fp16=True):
    
    # load tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        pre_trained_model_name,
        model_max_length=MODEL_MAX_LEN,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True,
    )
    
    predict_dataset = SupervisedDataset(tokenizer=tokenizer,
                                        data_path=os.path.join(data_path, "dev.csv"),
                                        kmer=data_kmer)
   
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
 
    # load model
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
    # model = PeftModel.from_pretrained(
        tuned_model_path,
        num_labels=predict_dataset.num_labels,
        trust_remote_code=True,
        # is_trainable=True
    )
    
    prediction_args = TrainingArguments(
        output_dir=output_dir,
        model_max_length=MODEL_MAX_LEN,
        per_device_eval_batch_size= BATCH_SIZE,
        evaluation_strategy="steps",
        save_strategy="steps",
        dataloader_drop_last=False,
        overwrite_output_dir=True,
        fp16=is_fp16,
        find_unused_parameters=False,
        remove_unused_columns=True,
        seed=SEED
    )
    
    # define trainer
    trainer = transformers.Trainer(model=model,
                                   tokenizer=tokenizer,
                                   args=prediction_args,
                                   compute_metrics=compute_metrics,
                                   train_dataset=None,
                                   eval_dataset=None,
                                   data_collator=data_collator)
       
    print("*** Predict ***")
    predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict_")
    
    for key, value in metrics.items():
        print("{}: {}".format(key, value))
WARNING:root:Perform single sequence classification...
Some weights of the model checkpoint at /Users/zingo/Library/Mobile Documents/com~apple~CloudDocs/Nvidia/code/DeepSNAP/tuned_models/NT_HumanRef500__Splice__RefSeq400_dummy_lora/ were not used when initializing EsmForSequenceClassification: ['esm.encoder.layer.15.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.19.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.17.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.5.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.15.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.1.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.16.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.2.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.6.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.20.attention.self.query.lora_B.default.weight', 'classifier.modules_to_save.default.dense.weight', 'esm.encoder.layer.8.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.13.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.16.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.12.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.7.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.4.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.12.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.18.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.13.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.1.attention.self.query.lora_B.default.weight', 'classifier.original_module.dense.weight', 'esm.encoder.layer.23.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.8.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.12.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.16.attention.self.query.lora_A.default.weight', 
'esm.encoder.layer.23.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.20.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.6.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.22.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.22.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.18.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.14.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.21.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.3.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.8.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.9.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.18.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.5.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.22.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.7.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.18.attention.self.query.lora_B.default.weight', 'classifier.original_module.out_proj.weight', 'esm.encoder.layer.14.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.20.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.4.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.19.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.17.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.2.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.9.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.1.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.5.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.0.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.11.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.22.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.6.attention.self.query.lora_A.default.weight', 
'esm.encoder.layer.7.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.5.attention.self.value.lora_A.default.weight', 'classifier.original_module.out_proj.bias', 'esm.encoder.layer.0.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.0.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.20.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.3.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.3.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.4.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.15.attention.self.query.lora_A.default.weight', 'classifier.modules_to_save.default.dense.bias', 'esm.encoder.layer.10.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.19.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.13.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.2.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.23.attention.self.value.lora_A.default.weight', 'classifier.modules_to_save.default.out_proj.bias', 'esm.encoder.layer.9.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.16.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.10.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.23.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.12.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.10.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.11.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.11.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.14.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.1.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.0.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.6.attention.self.value.lora_B.default.weight', 'classifier.modules_to_save.default.out_proj.weight', 'esm.encoder.layer.10.attention.self.value.lora_A.default.weight', 
'esm.encoder.layer.4.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.21.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.14.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.19.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.13.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.9.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.17.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.8.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.3.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.7.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.21.attention.self.value.lora_B.default.weight', 'esm.encoder.layer.2.attention.self.query.lora_B.default.weight', 'esm.encoder.layer.17.attention.self.query.lora_A.default.weight', 'esm.encoder.layer.15.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.21.attention.self.value.lora_A.default.weight', 'esm.encoder.layer.11.attention.self.query.lora_B.default.weight', 'classifier.original_module.dense.bias']
- This IS expected if you are initializing EsmForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at /Users/zingo/Library/Mobile Documents/com~apple~CloudDocs/Nvidia/code/DeepSNAP/tuned_models/NT_HumanRef500__Splice__RefSeq400_dummy_lora/ and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
*** Predict ***
100%|██████████| 13/13 [01:23<00:00,  6.46s/it]
predict__loss: 1.3689093589782715
predict__accuracy: 0.01
predict__f1: 0.006600660066006601
predict__matthews_correlation: 0.0
predict__precision: 0.0033333333333333335
predict__recall: 0.3333333333333333
predict__runtime: 91.7778
predict__samples_per_second: 1.09
predict__steps_per_second: 0.142
*** Writing Predictions in probs.np ***
```
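The metric values above are what a classifier head that predicts the same class for every sample would produce. That fits the warning that `classifier.dense.*` and `classifier.out_proj.*` were newly initialized while the checkpoint's `classifier.modules_to_save.default.*` and `lora_A`/`lora_B` weights went unused. The sketch below reproduces the numbers under assumptions chosen to match them: 3 classes, about 100 samples (1.09 samples/s over 91.78 s), macro averaging, and a constant prediction of a class that only 1 of the 100 samples belongs to. The label counts are hypothetical.

```python
def macro_metrics(y_true, y_pred, num_classes):
    """Accuracy plus macro-averaged precision, recall and F1.

    A class that is never predicted gets precision 0 instead of a division
    by zero (the usual zero_division=0 convention).
    """
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in pairs)
        fp = sum(t != c and p == c for t, p in pairs)
        fn = sum(t == c and p != c for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    return {
        "accuracy": accuracy,
        "precision": sum(precisions) / num_classes,
        "recall": sum(recalls) / num_classes,
        "f1": sum(f1s) / num_classes,
    }


# Hypothetical label distribution: 1 sample of class 0, 49 of class 1, 50 of class 2.
y_true = [0] * 1 + [1] * 49 + [2] * 50
# An untrained head that always outputs class 0.
y_pred = [0] * len(y_true)

for name, value in macro_metrics(y_true, y_pred, num_classes=3).items():
    print(f"predict__{name}: {value}")
```

Matthews correlation is 0 for a constant predictor under the usual zero-denominator convention, which also matches the log. Since the key names suggest the checkpoint was saved from a PEFT-wrapped model, it probably has to be loaded through PEFT (for example `PeftModel.from_pretrained(base_model, checkpoint_dir)`), or merged with `merge_and_unload()` before saving, for the LoRA and classifier weights to be used; this has not been verified against the script that produced the log.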

Tutorial no longer works with the v2 models

Hi, thanks for releasing these models.

I think that with the updated v2 models, the API of get_pretrained_model and a few other places changed, so the tutorial no longer works for 500M_human_ref, although it does work for the new models such as 50M_multi_species_v2. Similarly, the outs dict now only has a logits key rather than embeddings_20.

See the traceback below:

# Get pretrained model
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    embeddings_layers_to_save=(20,),
    attention_maps_to_save=((1, 4), (7, 18)),
    max_positions=32,
)
forward_fn = hk.transform(forward_fn)
checkpoints/500M_human_ref/hyperparams.json

/root/.cache/nucleotide_transformer/500M_human_ref/hyperparams.json: 100%|██████████| 815/815 [00:00<00:00, 4.57kB/s]
/root/.cache/nucleotide_transformer/500M_human_ref/ckpt.joblib: 100%|██████████| 1.94G/1.94G [00:24<00:00, 80.0MB/s]

---------------------------------------------------------------------------

KeyError                                  Traceback (most recent call last)

<ipython-input-5-244649b762aa> in <cell line: 2>()
      1 # Get pretrained model
----> 2 parameters, forward_fn, tokenizer, config = get_pretrained_model(
      3     model_name=model_name,
      4     embeddings_layers_to_save=(20,),
      5     attention_maps_to_save=((1, 4), (7, 18)),

/usr/local/lib/python3.10/dist-packages/nucleotide_transformer/pretrained.py in get_pretrained_model(model_name, compute_dtype, param_dtype, output_dtype, embeddings_layers_to_save, attention_maps_to_save, max_positions)
    276         ffn_embed_dim=hyperparams["ffn_embed_dim"],
    277         num_layers=hyperparams["num_layers"],
--> 278         positional_embedding=hyperparams["positional_embedding"],
    279         add_bias_kv=add_bias_kv,
    280         add_bias_ffn=add_bias_ffn,

KeyError: 'positional_embedding'
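The `KeyError` comes from the installed `get_pretrained_model` reading `hyperparams["positional_embedding"]`, a key that the downloaded `500M_human_ref/hyperparams.json` apparently does not contain. Until the package and the checkpoints agree again (the "Fix v2 api" branch listed on this page targets this breakage), a backward-compatible lookup could look like the sketch below. The default value `"learned"` is an assumption: the v1 models use learned positional embeddings, but the exact string the library expects is not shown here, and `load_hyperparams` is a hypothetical helper.

```python
import json

# Hypothetical defaults for keys that older (v1) hyperparams.json files may lack.
# "learned" is an assumption: v1 checkpoints use learned positional embeddings,
# but the exact string the library expects may differ.
V1_DEFAULTS = {"positional_embedding": "learned"}


def load_hyperparams(text, defaults=V1_DEFAULTS):
    """Parse a hyperparams.json payload and fill in missing keys from defaults."""
    hyperparams = json.loads(text)
    for key, value in defaults.items():
        # setdefault never overrides a value that the file already provides.
        hyperparams.setdefault(key, value)
    return hyperparams


# A v1-style payload without "positional_embedding" (values are illustrative).
params = load_hyperparams('{"num_layers": 2}')
print(params["positional_embedding"])
```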

model size mismatch

Hello there,

Recently I tried to load the 'InstaDeepAI/nucleotide-transformer-v2-50m-multi-species' weights to generate sequence embeddings. However, I ran into this issue while running the snippet code. The error is shown below:

RuntimeError: Error(s) in loading state_dict for EsmForMaskedLM:
size mismatch for esm.encoder.layer.0.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.1.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.2.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.3.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.4.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.5.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.6.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.7.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.8.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.9.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.10.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.11.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.12.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.13.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.14.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.15.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.16.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.17.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.18.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.19.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.20.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
size mismatch for esm.encoder.layer.21.intermediate.dense.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
You may consider adding ignore_mismatched_sizes=True in the model from_pretrained method.

The model weights and config were downloaded from https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-100m-multi-species/tree/main and used in offline mode.
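Every message follows one pattern: the checkpoint's `intermediate.dense.weight` has shape [4096, 512], exactly twice the [2048, 512] that the stock `EsmForMaskedLM` builds from the config. That is consistent with the v2 models using a gated (SwiGLU-style) feed-forward layer whose first projection is twice as wide, which the stock ESM class does not implement. The v2 model cards on Hugging Face load these checkpoints with `trust_remote_code=True` so that the repository's own modeling code is used; in offline mode that code file would have to be present in the local copy as well. `ignore_mismatched_sizes=True` would only silence the error by randomly reinitializing these layers, which defeats the purpose for embeddings. The hypothetical helper below just parses such messages so the pattern is easy to check.

```python
import re

# Matches messages such as:
# "size mismatch for esm.encoder.layer.0.intermediate.dense.weight: copying a param
#  with shape torch.Size([4096, 512]) from checkpoint, the shape in current model
#  is torch.Size([2048, 512])."
MISMATCH_RE = re.compile(
    r"size mismatch for (?P<name>\S+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]+)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[(?P<model>[\d, ]+)\]\)"
)


def parse_shape(text):
    """Turn '4096, 512' into (4096, 512)."""
    return tuple(int(x) for x in text.split(","))


def summarize_mismatches(log):
    """Return {param_name: (checkpoint_shape, model_shape)} for each mismatch."""
    return {
        m["name"]: (parse_shape(m["ckpt"]), parse_shape(m["model"]))
        for m in MISMATCH_RE.finditer(log)
    }


log = (
    "size mismatch for esm.encoder.layer.0.intermediate.dense.weight: copying a param "
    "with shape torch.Size([4096, 512]) from checkpoint, the shape in current model "
    "is torch.Size([2048, 512]).\n"
)
for name, (ckpt, model) in summarize_mismatches(log).items():
    print(name, "checkpoint/model rows:", ckpt[0] / model[0])
```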

NT nucleotide-level splice site prediction

Hi,
Thank you for sharing your paper and code! I was wondering whether you plan to make available the code for the NT model fine-tuned for nucleotide-level splice site prediction featured in Figure 5d.
Thank you so much for considering my request!

Where are the intermediate checkpoints?

Dear authors of NT, I very much appreciate your research. I am doing some benchmarking of your model; could you provide some intermediate checkpoints for us to benchmark?

Hugging Face Hub

Hello, thanks for the great paper!

Do you have any plans to upload the model to Hugging Face Hub in the near future?

Training process for Nucleotide Transformer Model

Thank you for sharing Nucleotide Transformer! I am trying to analyze RNA sequences using a custom dataset that is not compatible with the available pretrained models. I would like to train the model on my own dataset for a classification task, however I am having trouble implementing a training loop that is compatible with Nucleotide Transformer. Is it possible to expand the codebase to include examples of your training procedures and implemented metrics, etc. so a new pretrained model may be built?

Enable downloading model without tqdm

  • Tqdm is broken with some versions of Jupyter Notebook. This adds a verbose argument so that the weights can be downloaded without a progress bar, bypassing this problem.

hyenadna evaluation

Thanks for your great work!
I want to ask how you fine-tuned HyenaDNA: (1) as in their GitHub repository, with python -m train wandb=null experiment=hg38/nucleotide_transformer dataset_name=enhancer dataset.max_length=500 model.layer.l_max=1026
or (2) by downloading the weights and using the Hugging Face Trainer as in their Colab?

Confusion about the max_positions value

Hello,
In get_pretrained_model you set max_positions=32. I do not understand what this value of 32 means; could you please explain why it is set to this value?
Thanks
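The tracebacks further down this page show `get_pretrained_model` passing `max_positions` to the tokenizer as `fixed_length` (`FixedSizeNucleotidesKmersTokenizer(k_mers=hyperparams["k_for_kmers"], fixed_length=max_positions, prepend_cls_token=True)`). So with `max_positions=32`, every tokenized sequence is exactly 32 tokens long, including the prepended CLS token; a small value like this presumably just keeps the README example light. A minimal sketch of that fixed-length behavior follows. The `<cls>`/`<pad>` strings, the rule that leftover nucleotides become single-nucleotide tokens, and raising on overflow are assumptions, not the library's tokenizer.

```python
CLS, PAD = "<cls>", "<pad>"  # hypothetical special-token strings


def kmer_tokens(sequence, k=6):
    """Split into non-overlapping k-mers; leftover nucleotides become single tokens (assumption)."""
    n_full = len(sequence) // k * k
    tokens = [sequence[i:i + k] for i in range(0, n_full, k)]
    tokens.extend(sequence[n_full:])  # one token per leftover nucleotide
    return tokens


def fixed_length_tokens(sequence, max_positions, k=6):
    """Prepend CLS, then pad to exactly max_positions tokens.

    Returns (tokens, attention_mask). Raises if the sequence does not fit;
    the real tokenizer may handle overflow differently.
    """
    tokens = [CLS] + kmer_tokens(sequence, k)
    if len(tokens) > max_positions:
        raise ValueError(f"{len(tokens)} tokens exceed max_positions={max_positions}")
    n_pad = max_positions - len(tokens)
    attention_mask = [1] * len(tokens) + [0] * n_pad
    tokens += [PAD] * n_pad
    return tokens, attention_mask


tokens, mask = fixed_length_tokens("ATTCCGATTCCGATTCCG", max_positions=32)
print(tokens[:4], sum(mask), len(tokens))
```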

The length of input sequence for SegmentNT

Hello Team,

I'm trying to use your model to predict splice sites on custom sequences. Could you please share whether there are any limits on the input sequence, such as length and context? For example, SpliceAI needs 5,000 bp of context on each side; do you have any such requirement?

Thanks
Zitong
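The page records no answer to this. If, as we understand the SegmentNT README, the number of DNA tokens (excluding the prepended CLS token) must be divisible by 4, then with 6-mer tokenization of an A/C/G/T-only sequence the nucleotide length must be a multiple of 6 × 4 = 24. The sketch below, with hypothetical helper names, crops a sequence around its center to the longest such length; treat the divisibility constraint itself as an assumption to check against the README.

```python
def segmentnt_compatible_length(n_nucleotides, k=6, token_multiple=4):
    """Largest length <= n_nucleotides whose k-mer token count is a multiple of token_multiple.

    Assumes an A/C/G/T-only sequence (every token is a full k-mer once the
    length is a multiple of k) and that the CLS token is not counted.
    """
    step = k * token_multiple
    return n_nucleotides // step * step


def crop_center(sequence, k=6, token_multiple=4):
    """Crop symmetrically so the remaining length satisfies the assumed constraint."""
    target = segmentnt_compatible_length(len(sequence), k, token_multiple)
    start = (len(sequence) - target) // 2
    return sequence[start:start + target]


cropped = crop_center("A" * 30010)
print(len(cropped), len(cropped) // 6)
```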

Inquiry Regarding Details of Section A.5.4

I am particularly intrigued by the experiments outlined in section A.5.4, which focuses on Functional Variant Prioritization. As I attempt to replicate this specific experiment, I have encountered some challenges and would greatly appreciate additional details to aid in my efforts. Specifically, I am interested in the following aspects:

  1. Embedding Extraction:

Could you please clarify from which layer of the Transformer the embeddings are extracted?

  2. Similarity Calculation:

In the calculation of similarity, is it based solely on the embeddings of the mutated tokens, or on the embeddings of the entire sequence?

  3. Binary Similarity Threshold:

What similarity threshold is used for the binary (two-class) classification? Understanding this threshold is crucial for my replication efforts.

I have observed that the similarity between sequences with severe mutations tends to be exceptionally high (exceeding 0.999). To gain a deeper understanding and enhance the reproducibility of this experiment, I would be grateful for any additional insights or details you could provide.
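The page does not say which layer or threshold the paper used, but very high similarities have a simple arithmetic explanation if embeddings are averaged over the whole sequence: a variant directly changes one token out of up to about 1,000, so the pooled vectors barely move. The toy below uses random vectors as stand-ins for token embeddings and compares the cosine similarity of the changed token alone with that of the mean-pooled sequences. It is an illustration under that assumption, not the paper's method; a real model would also perturb neighboring tokens through attention.

```python
import math
import random


def cosine(u, v):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def mean_pool(token_embeddings):
    """Average a list of per-token vectors into one sequence vector."""
    n = len(token_embeddings)
    return [sum(col) / n for col in zip(*token_embeddings)]


rng = random.Random(0)
dim, n_tokens, mutated = 64, 1000, 500
# Toy "reference" embeddings: one random vector per token.
reference = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_tokens)]
# Toy "variant": only the token at the mutated position gets a new embedding.
variant = [row[:] for row in reference]
variant[mutated] = [rng.gauss(0, 1) for _ in range(dim)]

token_sim = cosine(reference[mutated], variant[mutated])
sequence_sim = cosine(mean_pool(reference), mean_pool(variant))
print(f"mutated token only: {token_sim:.3f}, mean-pooled sequence: {sequence_sim:.4f}")
```

Comparing only the embeddings at and around the variant position, as the question suggests, is the obvious way to avoid this dilution.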

clarifying context length and readme

  • Remove the learned positional embeddings that have not been trained from the weights of the 500M parameters models.
  • Adapt the hyperparams.json files of these models
  • Add a section on the tokenization process in the README.txt

Fine-tune codes

Thank you so much for sharing the pre-trained model. I am eager to explore its capabilities in some specific tasks and would like to fine-tune the model accordingly. Could you please provide some example code to help guide me through this process? Your assistance is greatly appreciated!

Issues with embedding with 2B5 models

Hello,
I am trying to use the provided code to embed a sequence (I am simply copy-pasting the code from the README).
With the 2 smaller models I have no issue.
However, the two 2B5 models throw the following errors (a different one for each model):
2B5_1000G:

---------------------------------------------------------------------------
error                                     Traceback (most recent call last)
Input In [31], in <cell line: 7>()
      4 from nucleotide_transformer.pretrained import get_pretrained_model
      6 # Get pretrained model
----> 7 parameters, forward_fn, tokenizer, config = get_pretrained_model(
      8     model_name="2B5_1000G",
      9     mixed_precision=False,
     10     embeddings_layers_to_save=(20,),
     11     max_positions=32,
     12 )
     13 forward_fn = hk.transform(forward_fn)
     15 # Get data and tokenize it

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/nucleotide_transformer/pretrained.py:197, in get_pretrained_model(model_name, mixed_precision, embeddings_layers_to_save, attention_maps_to_save, max_positions)
    192     raise NotImplementedError(
    193         f"Unknown {model_name} model. " f"Supported models are {supported_models}"
    194     )
    196 # Download weights and hyperparams
--> 197 parameters, hyperparams = download_ckpt_and_hyperparams(model_name)
    199 tokenizer = FixedSizeNucleotidesKmersTokenizer(
    200     k_mers=hyperparams["k_for_kmers"],
    201     fixed_length=max_positions,
    202     prepend_cls_token=True,
    203 )
    205 # Get config

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/nucleotide_transformer/pretrained.py:99, in download_ckpt_and_hyperparams(model_name)
     96         hyperparams = json.load(f)
     98     with open(params_save_dir, "rb") as f:
---> 99         params = joblib.load(f)
    101     return params, hyperparams
    103 else:

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/numpy_pickle.py:648, in load(filename, mmap_mode)
    646     filename = getattr(fobj, 'name', '')
    647     with _read_fileobject(fobj, filename, mmap_mode) as fobj:
--> 648         obj = _unpickle(fobj)
    649 else:
    650     with open(filename, 'rb') as f:

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/numpy_pickle.py:577, in _unpickle(fobj, filename, mmap_mode)
    575 obj = None
    576 try:
--> 577     obj = unpickler.load()
    578     if unpickler.compat_mode:
    579         warnings.warn("The file '%s' has been generated with a "
    580                       "joblib version less than 0.10. "
    581                       "Please regenerate this pickle file."
    582                       % filename,
    583                       DeprecationWarning, stacklevel=3)

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/pickle.py:1213, in _Unpickler.load(self)
   1211             raise EOFError
   1212         assert isinstance(key, bytes_types)
-> 1213         dispatch[key[0]](self)
   1214 except _Stop as stopinst:
   1215     return stopinst.value

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/numpy_pickle.py:415, in NumpyUnpickler.load_build(self)
    413 if isinstance(array_wrapper, NDArrayWrapper):
    414     self.compat_mode = True
--> 415 self.stack.append(array_wrapper.read(self))

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/numpy_pickle.py:252, in NumpyArrayWrapper.read(self, unpickler)
    250     array = self.read_mmap(unpickler)
    251 else:
--> 252     array = self.read_array(unpickler)
    254 # Manage array subclass case
    255 if (hasattr(array, '__array_prepare__') and
    256     self.subclass not in (unpickler.np.ndarray,
    257                           unpickler.np.memmap)):
    258     # We need to reconstruct another subclass

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/numpy_pickle.py:177, in NumpyArrayWrapper.read_array(self, unpickler)
    175 read_count = min(max_read_count, count - i)
    176 read_size = int(read_count * self.dtype.itemsize)
--> 177 data = _read_bytes(unpickler.file_handle,
    178                    read_size, "array data")
    179 array[i:i + read_count] = \
    180     unpickler.np.frombuffer(data, dtype=self.dtype,
    181                             count=read_count)
    182 del data

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/numpy_pickle_utils.py:243, in _read_bytes(fp, size, error_template)
    238 while True:
    239     # io files (default in python3) return None or raise on
    240     # would-block, python2 file will truncate, probably nothing can be
    241     # done about that.  note that regular files can't be non-blocking
    242     try:
--> 243         r = fp.read(size - len(data))
    244         data += r
    245         if len(r) == 0 or len(data) == size:

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/compressor.py:464, in BinaryZlibFile.readinto(self, b)
    459 """Read up to len(b) bytes into b.
    460 
    461 Returns the number of bytes read (0 for EOF).
    462 """
    463 with self._lock:
--> 464     return io.BufferedIOBase.readinto(self, b)

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/compressor.py:456, in BinaryZlibFile.read(self, size)
    454     return self._read_all()
    455 else:
--> 456     return self._read_block(size)

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/compressor.py:429, in BinaryZlibFile._read_block(self, n_bytes, return_data)
    426 self._buffer_offset = 0
    428 blocks = []
--> 429 while n_bytes > 0 and self._fill_buffer():
    430     if n_bytes < len(self._buffer):
    431         data = self._buffer[:n_bytes]

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/compressor.py:393, in BinaryZlibFile._fill_buffer(self)
    391         return False
    392     else:
--> 393         self._buffer = self._decompressor.decompress(rawblock)
    394     self._buffer_offset = 0
    395 return True

error: Error -3 while decompressing data: invalid code lengths set

2B5_multi_species:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [30], in <cell line: 7>()
      4 from nucleotide_transformer.pretrained import get_pretrained_model
      6 # Get pretrained model
----> 7 parameters, forward_fn, tokenizer, config = get_pretrained_model(
      8     model_name="2B5_multi_species",
      9     mixed_precision=False,
     10     embeddings_layers_to_save=(20,),
     11     max_positions=32,
     12 )
     13 forward_fn = hk.transform(forward_fn)
     15 # Get data and tokenize it

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/nucleotide_transformer/pretrained.py:197, in get_pretrained_model(model_name, mixed_precision, embeddings_layers_to_save, attention_maps_to_save, max_positions)
    192     raise NotImplementedError(
    193         f"Unknown {model_name} model. " f"Supported models are {supported_models}"
    194     )
    196 # Download weights and hyperparams
--> 197 parameters, hyperparams = download_ckpt_and_hyperparams(model_name)
    199 tokenizer = FixedSizeNucleotidesKmersTokenizer(
    200     k_mers=hyperparams["k_for_kmers"],
    201     fixed_length=max_positions,
    202     prepend_cls_token=True,
    203 )
    205 # Get config

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/nucleotide_transformer/pretrained.py:99, in download_ckpt_and_hyperparams(model_name)
     96         hyperparams = json.load(f)
     98     with open(params_save_dir, "rb") as f:
---> 99         params = joblib.load(f)
    101     return params, hyperparams
    103 else:

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/numpy_pickle.py:648, in load(filename, mmap_mode)
    646     filename = getattr(fobj, 'name', '')
    647     with _read_fileobject(fobj, filename, mmap_mode) as fobj:
--> 648         obj = _unpickle(fobj)
    649 else:
    650     with open(filename, 'rb') as f:

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/site-packages/joblib/numpy_pickle.py:577, in _unpickle(fobj, filename, mmap_mode)
    575 obj = None
    576 try:
--> 577     obj = unpickler.load()
    578     if unpickler.compat_mode:
    579         warnings.warn("The file '%s' has been generated with a "
    580                       "joblib version less than 0.10. "
    581                       "Please regenerate this pickle file."
    582                       % filename,
    583                       DeprecationWarning, stacklevel=3)

File ~/lib/software/miniconda3/envs/torch-p310/lib/python3.10/pickle.py:1213, in _Unpickler.load(self)
   1211             raise EOFError
   1212         assert isinstance(key, bytes_types)
-> 1213         dispatch[key[0]](self)
   1214 except _Stop as stopinst:
   1215     return stopinst.value

KeyError: 188
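Both failures happen inside `joblib.load` while reading `ckpt.joblib`: zlib's "invalid code lengths set" and a `KeyError` on a pickle opcode byte both suggest a corrupted or incompletely downloaded checkpoint rather than a bug in the model code. Because `download_ckpt_and_hyperparams` appears to load the cached files whenever they exist, a corrupted file keeps failing until it is removed. The hypothetical helper below deletes one model's cache folder so that the next `get_pretrained_model` call downloads it again. The cache location is taken from the download logs elsewhere on this page and may differ in older package versions.

```python
import shutil
from pathlib import Path

# Cache layout as seen in the download logs on this page, e.g.
# ~/.cache/nucleotide_transformer/500M_human_ref/ckpt.joblib
# (assumption: older package versions may use a different location).
DEFAULT_CACHE = Path.home() / ".cache" / "nucleotide_transformer"


def clear_cached_model(model_name, cache_dir=DEFAULT_CACHE):
    """Delete the cached folder of one model; return True if something was removed."""
    model_dir = Path(cache_dir) / model_name
    if not model_dir.is_dir():
        return False
    shutil.rmtree(model_dir)  # removes ckpt.joblib and hyperparams.json together
    return True
```

Usage would be `clear_cached_model("2B5_1000G")` followed by the original `get_pretrained_model(...)` call; with multi-gigabyte checkpoints it is also worth confirming that the re-download completes and that the disk has enough free space.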

2 sequences as input

Thank you so much for sharing the pre-trained model! I am eager to explore its capabilities in some specific tasks and would like to fine-tune the model accordingly. Is it possible to provide 2 sequences as input to perform a classification task?
Thank you!
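The page does not record whether the NT tokenizer accepts sequence pairs. One common workaround, shown here as an assumption rather than a documented NT feature, is to embed each sequence separately (for example with a mean-pooled embedding) and train a small classifier on features built from the pair. The helper name is hypothetical.

```python
def pair_features(u, v):
    """Combine two sequence embeddings into one feature vector for a pair classifier.

    Layout: [u, v, |u - v|, u * v] (element-wise), a common choice for
    sentence-pair classifiers. u and v are kept separately, so swapping the
    inputs changes the features.
    """
    if len(u) != len(v):
        raise ValueError("embeddings must have the same dimension")
    abs_diff = [abs(a - b) for a, b in zip(u, v)]
    product = [a * b for a, b in zip(u, v)]
    return list(u) + list(v) + abs_diff + product


features = pair_features([0.5, -1.0], [1.5, 2.0])
print(features)
```

The resulting vectors can then be fed to any standard classifier; the |u - v| and u * v terms let even a linear model react to how the two sequences differ.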

Bug: jax version issues on TPU Colab

Context

Using the example Colab on a TPU instance, users get the following error when running the second code chunk:

[Screenshot of the error]

Attempted Fixes

Updating the jax version to solve the issue

!pip install -U jax jaxlib

This now leads to a new error when running the second code chunk:

[Screenshot of the new error]

Suggested Fix

Delete the following lines from the second code chunk of the Google Colab example:

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

This avoids the conflict between jax and the TPU setup. It reduces the notebook's functionality by removing TPUs from consideration, but it also reduces user friction by removing the jax/TPU issues.

Another notebook demonstrating full TPU functionality can then be developed once a proper fix for the jax issue is found; however, judging from the error message I encountered, that might be difficult in the Colab ecosystem.
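Instead of deleting the setup lines outright, they could be guarded so the notebook keeps running when `jax.tools.colab_tpu` cannot be imported (as far as we know, recent JAX releases no longer ship it). This is a sketch: the wrapper name is hypothetical, and whether any setup call is still needed on current Colab TPU runtimes is not established here.

```python
import os


def maybe_setup_colab_tpu(environ=os.environ):
    """Run the legacy Colab TPU setup only when it applies and is importable.

    Returns True if colab_tpu.setup_tpu() was called, False otherwise.
    """
    if "COLAB_TPU_ADDR" not in environ:
        return False
    try:
        from jax.tools import colab_tpu  # missing in newer JAX releases
    except ImportError:
        print("jax.tools.colab_tpu not available; skipping legacy TPU setup.")
        return False
    colab_tpu.setup_tpu()
    return True


maybe_setup_colab_tpu()
```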

how to deal with the seq position of reference genome?

Hi dallatt,

Thanks for your work and for providing such a great tool.

I would like to know how you handle the position information of the reference genome sequence. I saw in the article that, during data preparation, the tokens of the reference sequence were replaced with each individual's variants at the corresponding positions. I don't know how this step is implemented, because I don't see any position-related input in your code.


Looking forward to your reply, thanks!
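The page records no answer, but position information can live entirely in data preparation rather than in the model input: if the genomic start coordinate of a reference window is known, each variant's position becomes an offset into the window, and the base is replaced before tokenization. The sketch below, with hypothetical names, handles single-nucleotide variants only (no indels or phasing) and uses 1-based, VCF-style positions; it is an illustration, not the authors' pipeline.

```python
def apply_snvs(reference_window, window_start, variants):
    """Return the window with single-nucleotide variants applied.

    reference_window: reference bases starting at genomic position window_start.
    window_start: 1-based genomic coordinate of the first base (VCF convention).
    variants: iterable of (position, ref_base, alt_base) with 1-based positions.
    Variants outside the window are skipped; a reference-base mismatch raises.
    """
    bases = list(reference_window)
    for position, ref, alt in variants:
        offset = position - window_start
        if not 0 <= offset < len(bases):
            continue
        if bases[offset] != ref:
            raise ValueError(f"reference mismatch at {position}: {bases[offset]} != {ref}")
        bases[offset] = alt
    return "".join(bases)


# Hypothetical window starting at position 1001 with two SNVs, one outside it.
print(apply_snvs("ACGTACGTAC", 1001, [(1003, "G", "T"), (1020, "A", "C")]))
```

Checking the reference base catches off-by-one errors between 0-based and 1-based coordinates, which are the usual failure in this step.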

Pre-training on plant genomes

Hi,

The multi-species nucleotide transformer model appears to be very promising! I'm curious, do you have any plans to pre-train the nucleotide transformer on plant species?

Can I simply take 'CLS' token as my sequence representations?

Hello there,
I attempted to extract representations from the Nucleotide Transformer, specifically the 250M multi-species model. Is there a suggested method for deriving sequence representations from the token embeddings, or would it be more effective to use the CLS token as the representation of my sequences?
To provide more context, these representations I'm seeking to extract are intended as initial embeddings for downstream tasks. The sequence lengths I'm working with vary significantly, ranging from 10 base pairs to several thousand base pairs.

Thanks in advance!
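No answer is recorded here. For comparison with the CLS token, the Hugging Face model cards for these models compute a mean embedding over the non-padding tokens using the attention mask, which also copes with sequences of very different lengths because padding never enters the average. A minimal masked-mean sketch in plain Python follows; the `skip_first` option for excluding the CLS token is an addition of this sketch, and which representation works better for a given downstream task is an empirical question.

```python
def masked_mean(token_embeddings, attention_mask, skip_first=False):
    """Average the embeddings of real (mask == 1) tokens.

    token_embeddings: list of per-token vectors, padded to a common length.
    attention_mask: list of 0/1 flags, 1 for real tokens.
    skip_first: also drop position 0 (the CLS token) from the average.
    """
    selected = [
        emb
        for i, (emb, keep) in enumerate(zip(token_embeddings, attention_mask))
        if keep and not (skip_first and i == 0)
    ]
    if not selected:
        raise ValueError("no tokens left to average")
    return [sum(col) / len(selected) for col in zip(*selected)]


# CLS, two real tokens, then one padding token (toy 2-dimensional embeddings).
embeddings = [[1.0, 1.0], [2.0, 4.0], [4.0, 2.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]
print(masked_mean(embeddings, mask))
print(masked_mean(embeddings, mask, skip_first=True))
```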

Extend SegmentNT notebook

  • Add the reproduction of Fig. 1 in the SegmentNT notebook
  • Change the pre-commit so that the notebooks outputs are not cleaned

Clarifications about context size

Could you clarify what's the recommended context size for the models?

  • The default max_positions is 1024.
  • The paper mentions maximum of 1000 tokens.
  • The paper mentions 6kb sequences, would that mean 1000 6-mer tokens + 1 CLS token = 1001 tokens?

Thank you!

Fix v2 api

Handling the v2 checkpoints broke the API for v1 models. This branch aims to fix this.

Cannot download the pre-training weight

Thanks for your work. It seems that we cannot download the pre-trained weights with the example code:

/home/tl688/.cache/nucleotide_transformer/1B_agro_nt/hyperparams.json: 100%|██████████| 1.24k/1.24k [00:00<00:00, 6.99kB/s]
/home/tl688/.cache/nucleotide_transformer/1B_agro_nt/ckpt.joblib:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

The second progress bar stays at 0% for a long time. Thanks a lot.
