
Trying to train LoRAs for ESM-2

Home Page: https://huggingface.co/AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites


esm2_loras's Introduction

This is an attempt at training a Low-Rank Adaptation (LoRA) for the protein language model ESM-2 on a token classification task; in particular, we attempt to train an RNA binding site predictor. There are still some issues to work out, and any feedback or advice would be much appreciated. Because this code targets a small model, it should run wandb sweeps for hyperparameter search in a reasonable amount of time on almost any GPU; you can easily swap in a larger model if you want.
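
For reference, here is a minimal sketch of how a LoRA can be attached to ESM-2 for token classification with the peft library. The target module names match the query/value projections in the Hugging Face ESM implementation, but the rank, alpha, and dropout values are illustrative assumptions, not the exact settings used by the training script:

from transformers import AutoModelForTokenClassification
from peft import LoraConfig, TaskType, get_peft_model

# Load the small ESM-2 base model with a two-label token classification head
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", num_labels=2
)

# Illustrative LoRA hyperparameters (r, lora_alpha, lora_dropout are assumptions)
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=8,                                # rank of the low-rank update matrices
    lora_alpha=16,                      # scaling factor for the LoRA updates
    lora_dropout=0.1,
    target_modules=["query", "value"],  # attention projections in EsmSelfAttention
)

# Wrap the base model so that only the LoRA parameters remain trainable
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()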

Model Weights and Config

The model itself

"AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites"

can be found on Hugging Face at https://huggingface.co/AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites.

Setting up this repo

To set up the conda environment, clone the repo and run:

conda env create -f environment.yml

Then run:

conda activate lora_esm_2

To train the model run:

from lora_esm2_script import train_protein_model

train_protein_model()
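
If you want to run a wandb hyperparameter sweep, a minimal sketch follows. The sweep parameters, the metric name, and the assumption that train_protein_model reads its hyperparameters from wandb.config are all illustrative:

import wandb
from lora_esm2_script import train_protein_model

# Illustrative random sweep over learning rate and LoRA rank; assumes
# train_protein_model pulls these values from wandb.config
sweep_config = {
    "method": "random",
    "metric": {"name": "eval/f1", "goal": "maximize"},  # hypothetical metric name
    "parameters": {
        "learning_rate": {"min": 1e-5, "max": 1e-3},
        "lora_r": {"values": [4, 8, 16]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="esm2_loras")  # hypothetical project name
wandb.agent(sweep_id, function=train_protein_model, count=10)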

To use the trained model for inference, try running:

from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "esm2_t6_8M-finetuned-lora_2023-08-03_18-32-25"
# ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)

# New unseen protein sequence
new_protein_sequence = "FDLNDFLEQKVLVRMEAIINSMTMKERAKPEIIKGSRKRRIAAGSGMQVQDVNRLLKQFDDMQRMMKKM"

# Tokenize the new sequence
inputs = loaded_tokenizer(new_protein_sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt")

# Put the model in evaluation mode (disables dropout) and make predictions
loaded_model.eval()
with torch.no_grad():
    outputs = loaded_model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=2)

# Print logits for debugging
print("Logits:", logits)

# Convert predictions to a list
predicted_labels = predictions.squeeze().tolist()

# Get input IDs to identify padding and special tokens
input_ids = inputs['input_ids'].squeeze().tolist()

# Define a set of token IDs that correspond to special tokens
special_tokens_ids = {loaded_tokenizer.cls_token_id, loaded_tokenizer.pad_token_id, loaded_tokenizer.eos_token_id}

# Filter the predicted labels using the special_tokens_ids to remove predictions for special tokens
binding_sites = [label for label, token_id in zip(predicted_labels, input_ids) if token_id not in special_tokens_ids]

print("Predicted binding sites:", binding_sites)

esm2_loras's People

Contributors: amelie-schreiber

esm2_loras's Issues

ValueError: EsmForTokenClassification does not support gradient checkpointing

Hi Amelie, I am looking to fine-tune an ESM model on my small downstream task. I am following your article https://huggingface.co/blog/AmelieSchreiber/esm2-ptm.

I am getting the following error from the gradient_checkpointing_enable() function, for which you might already have a solution or suggestions:

    model.gradient_checkpointing_enable() # SP commented
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sureshp/anaconda3/envs/qlora/lib/python3.11/site-packages/transformers/modeling_utils.py", line 1631, in gradient_checkpointing_enable
    raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
ValueError: EsmForTokenClassification does not support gradient checkpointing.

Thanks.
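
One possible guard, sketched under the assumption that you can simply skip checkpointing when the architecture does not support it (this works around the error rather than fixing the underlying limitation in that transformers version):

# Only enable gradient checkpointing when the model class supports it
if getattr(model, "supports_gradient_checkpointing", False):
    model.gradient_checkpointing_enable()
else:
    print(f"{model.__class__.__name__} does not support gradient checkpointing; skipping.")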

Incorrectly loading trained weights

Hi, this is Kiarash.

You are incorrectly loading your already-trained weights. LoRA weights are not the same as other HF weights: you need to load them on top of your base model. This was already described in the link I sent you, but for your convenience, here are the exact lines of code you would need to run.

from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

# Path to the saved LoRA model
model_path = "esm2_t6_8M-finetuned-lora_2023-08-03_18-32-25"
# ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
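
As a follow-up, if you want a standalone model without the peft wrapper, peft can merge the LoRA deltas into the base weights via merge_and_unload(); the output path below is hypothetical:

# Merge the LoRA weights into the base model and drop the peft wrapper
merged_model = loaded_model.merge_and_unload()
merged_model.save_pretrained("esm2_t6_8M_rna_binding_merged")  # hypothetical path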
