License: MIT License

APT

APT: Adaptive Pruning and Tuning Pretrained Language Models for Efficient Training and Inference

News

  • [2024/01/22] Our paper is now on arXiv! Check it out here: https://arxiv.org/abs/2401.12200
  • [2024/05/01] APT has been accepted to ICML 2024 as an oral presentation!

Overview

Figure: The APT framework.

We propose APT, a method that Adaptively selects model parameters for Pruning and fine-Tuning. APT combines the benefits of parameter-efficient fine-tuning (PEFT) and structured pruning to make both fine-tuning and inference more efficient.

How APT Works

Our intuition is that pretrained language model (LM) parameters encode general knowledge, but their importance to a downstream task varies. We can therefore remove parameters that are irrelevant to the fine-tuning task early in training; this early removal improves training and inference efficiency without substantially hurting accuracy. Meanwhile, continuously adding parameters for fine-tuning improves LM performance, because task-specific skills live in a subset of LM parameters.

Based on this setup, we find that self-distillation in which the teacher and student share their main parameters lets us aggressively prune small LMs while retaining high end-task performance. For large LMs, accounting for in-block outliers by computing kurtosis when pruning before training yields accurate pruning with a smaller training-memory footprint.
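
To make the kurtosis idea concrete, here is a minimal illustrative sketch (not the repository's actual scoring code; the block size, the threshold, and the way scores feed into structured pruning are all assumptions made for the example): weight blocks with heavy-tailed distributions, i.e. high kurtosis, likely contain outliers and are kept, while low-kurtosis blocks become pruning candidates.

# Illustrative sketch of kurtosis-based outlier scoring (not the APT implementation:
# block size, threshold, and how scores feed into structured pruning are assumptions).
import torch

def block_kurtosis(weight: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Excess kurtosis of each contiguous block of `block_size` weights."""
    w = weight.flatten()
    n = (w.numel() // block_size) * block_size
    blocks = w[:n].reshape(-1, block_size)
    mean = blocks.mean(dim=1, keepdim=True)
    var = blocks.var(dim=1, unbiased=False, keepdim=True)
    fourth = ((blocks - mean) ** 4).mean(dim=1, keepdim=True)
    return (fourth / (var ** 2 + 1e-12)).squeeze(1) - 3.0

# High-kurtosis blocks contain outlier weights and are kept; low-kurtosis blocks
# become candidates for structured removal (here: the lowest ~30%, an arbitrary ratio).
weight = torch.randn(4096, 4096)
scores = block_kurtosis(weight)
keep_mask = scores > torch.quantile(scores, 0.3)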

Main Results

RoBERTa-base experiment results (efficiency columns are relative to full fine-tuning, FT = 100%):

| Method | MNLI | SST-2 | SQuAD v2 | Train Time | Train Mem. | Inf. Time | Inf. Mem. |
|---|---|---|---|---|---|---|---|
| FT | 87.6 | 94.8 | 82.9 | 100.0% | 100.0% | 100.0% | 100.0% |
| LoRA | 87.5 | 95.1 | 83.0 | 2137.0% | 60.5% | 100.0% | 100.0% |
| LoRA+Prune | 84.0 | 93.0 | 79.2 | 5128.3% | 60.5% | 38.0% | 75.1% |
| Prune+Distill | 87.3 | 94.5 | - | 1495.3% | 168.5% | 38.6% | 79.2% |
| LoRA+Prune+Distill | 84.2 | 91.9 | - | 6534.6% | 141.4% | 39.4% | 82.3% |
| APT | 86.4 | 94.5 | 81.8 | 592.1% | 70.1% | 41.3% | 78.1% |

T5-base experiment results (efficiency columns are relative to full fine-tuning, FT = 100%):

| Method | MNLI | SST-2 | CNN/DM | Train Time | Train Mem. | Inf. Time | Inf. Mem. |
|---|---|---|---|---|---|---|---|
| FT | 87.1 | 95.2 | 42.1/20.3/39.4 | 100.0% | 100.0% | 100.0% | 100.0% |
| LoRA | 87.0 | 95.0 | 38.7/17.2/36.0 | 255.5% | 62.0% | 100.0% | 100.0% |
| LoRA+Prune | 80.9 | 92.3 | 36.7/15.7/33.9 | 4523.5% | 62.0% | 47.1% | 73.4% |
| APT | 87.0 | 95.0 | 38.6/17.0/35.8 | 484.7% | 73.9% | 74.6% | 81.5% |

LLaMA2-7B experiment results (efficiency columns are relative to LoRA fine-tuning, LoRA = 100%):

| Method | ARC | HellaSwag | MMLU | TruthfulQA | Avg. | Train Time | Train Mem. | Inf. Time | Inf. Mem. |
|---|---|---|---|---|---|---|---|---|---|
| LLaMA 2 7B | 53.1 | 77.7 | 43.8 | 39.0 | 53.4 | - | - | - | - |
| LoRA | 55.6 | 79.3 | 46.9 | 49.9 | 57.9 | 100.0% | 100.0% | 100.0% | 100.0% |
| LoRA+Prune | 46.8 | 65.2 | 23.9 | 46.2 | 45.5 | 180.9% | 100.0% | 115.5% | 68.9% |
| LLMPruner | 39.2 | 67.0 | 24.9 | 40.6 | 42.9 | 86.9% | 253.6% | 114.8% | 74.2% |
| APT | 45.4 | 71.1 | 36.9 | 46.6 | 50.0 | 106.0% | 75.8% | 117.0% | 67.2% |

Setup

Installation

conda env create -f environment.yml
conda activate apt
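
Once the environment is active, a quick sanity check (assuming environment.yml provides PyTorch and Hugging Face Transformers, which the training scripts rely on; adjust if your environment pins different packages):

# Quick check that the `apt` environment provides the expected deep-learning stack
# (an assumption here: environment.yml is expected to pin PyTorch and Transformers).
import torch
import transformers

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)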

Training

For fine-tuning RoBERTa-base models with APT, please run:

bash scripts/adaptpruning/roberta_base_sst2_momentum.sh

For fine-tuning T5-base models with APT, please run:

bash scripts/adaptpruning/t5_base_lm_adapt_cnndm_momentum.sh

For fine-tuning LLaMA2 models on Alpaca with APT, please run:

bash scripts/adaptpruning/llama_2_7b_alpaca_gpt4.sh
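
If you want to queue the three recipes above in one go, here is a small convenience sketch (not part of the repository; it assumes you run it from the repository root inside the apt conda environment, and the paths are exactly the ones listed above):

# Convenience sketch: run the three APT recipes listed above back to back.
import subprocess

RECIPES = [
    "scripts/adaptpruning/roberta_base_sst2_momentum.sh",
    "scripts/adaptpruning/t5_base_lm_adapt_cnndm_momentum.sh",
    "scripts/adaptpruning/llama_2_7b_alpaca_gpt4.sh",
]

for recipe in RECIPES:
    print(f"==> {recipe}")
    subprocess.run(["bash", recipe], check=True)  # stop on the first failing recipe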

Citation

If you use this code or our tuned models, please cite our paper:

@misc{zhao2024apt,
      title={APT: Adaptive Pruning and Tuning Pretrained Language Models for Efficient Training and Inference}, 
      author={Bowen Zhao and Hannaneh Hajishirzi and Qingqing Cao},
      year={2024},
      eprint={2401.12200},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Acknowledgements

This project uses modified code from the following projects:

  • CoFiPruning, developed by Princeton NLP. We reuse its model backbone code for pruning BERT and RoBERTa; see models/modeling_bert.py and models/modeling_roberta.py.
  • Mask Tuning, developed by Woosuk Kwon. We use its code as a baseline for pruning LMs after training; see prune/fisher.py.

Issues

How to get the results shown in Table 3?

Hi,
I followed the command bash scripts/adaptpruning/llama_2_7b_alpaca_gpt4.sh with the Alpaca-GPT4 data downloaded from https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/data/alpaca_gpt4_data.json. After training the model, how do I get the results for APT reported in Table 3? Is there a separate evaluation script?
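
For context, the Table 3 metrics (ARC, HellaSwag, MMLU, TruthfulQA) correspond to the standard Open LLM Leaderboard tasks, which are typically scored with EleutherAI's lm-evaluation-harness. The sketch below shows that route under the assumption that the tuned checkpoint loads as a regular Hugging Face model; it is not necessarily the pipeline the authors used, and the checkpoint path is a placeholder.

# Hedged sketch: scoring a tuned checkpoint on the Table 3 benchmarks with EleutherAI's
# lm-evaluation-harness (pip install lm_eval). This is not necessarily how the APT
# authors produced Table 3; the checkpoint path is a placeholder, and a pruned APT
# model may need the repo's own model classes to load at all.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=/path/to/apt_llama2_7b_checkpoint",  # placeholder path
    tasks=["arc_challenge", "hellaswag", "mmlu", "truthfulqa_mc2"],
    batch_size=8,
    # Note: the Open LLM Leaderboard uses task-specific few-shot counts, not set here.
)
print(results["results"])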

Training OOM with LLaMA2-7B on an A100 (40 GB)

Hi, thanks for your nice work! I tried the default config scripts/adaptpruning/llama_2_7b_alpaca_gpt4.sh on an A100 (40 GB) GPU, but training goes out of memory. Section 5.4 ("Main Results") of the paper says "APT costs less than 24GB of memory when pruning 30% parameters in LLaMA2-7B models before tuning, ...". As I understand it, this means training should fit on a device with more than 24 GB of memory, so why do I still hit OOM on a 40 GB device?

Are there any settings I can change to train within this memory limit without OOM?

P.S.
I have already tried setting pruning_batch_size and training_batch_size to 1, but it didn't help.
