
BitDistiller's Introduction

[ACL 2024] BitDistiller: Unleashing the Potential of Sub-4-Bit LLMs via Self-Distillation [paper]

Implementing efficient sub-4-bit weight quantization (3 / 2 bits) in LLMs through advanced QAT-based Self-Distillation techniques.

[Figure] BitDistiller overview

[Figure] Comparing general language tasks with other methods

[Figure] Comparing reasoning benchmarks with other methods

[Animation] Example of 2-bit inference with a domain-specific LLM (MetaMath)

News

  • [2024/05] 🔥 BitDistiller has been accepted to the ACL 2024 main conference!

Contents

  1. Setup
  2. Running
  3. Evaluation
  4. Inference

1. Setup

  • Python 3.9, PyTorch >= 1.13

  • pip install -r requirement.txt

    (You may need to adjust the transformers version to match the model config.)

2. Running

Our results are produced by the following three steps:

2.1. Asymmetric Quantization

  • Determine the type of quantization: use nf3 for 3 bits and int for 2 bits. Set w_bit and quant_type accordingly.

  • Perform clipping before training and save the clipping values using dump_clip (see quantization/autoclip.py).

This step can match or surpass the low-bit PTQ results of GPTQ and AWQ.
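For intuition, the sketch below shows what asymmetric group-wise fake quantization with a searched clipping ratio looks like conceptually. It is a minimal illustration, not the repo's code: autoclip.py searches clip values on a calibration set (pile/code/gsm8k) and saves them via --dump_clip, while the function names and the weight-MSE search criterion below are assumptions.

    # Minimal sketch of asymmetric group-wise fake quantization plus a clip-ratio search.
    # Illustrative only; the actual logic lives in quantization/autoclip.py.
    import torch

    def asym_fake_quant(w, w_bit=2, group_size=128, clip_ratio=1.0):
        """Quantize-dequantize weights with one (scale, zero-point) pair per group."""
        shape = w.shape
        w = w.reshape(-1, group_size)
        w_max = w.amax(dim=1, keepdim=True) * clip_ratio     # clipped dynamic range
        w_min = w.amin(dim=1, keepdim=True) * clip_ratio
        q_max = 2 ** w_bit - 1                                # e.g. 3 for int2 (four levels)
        scale = (w_max - w_min).clamp(min=1e-5) / q_max
        zero = (-w_min / scale).round()
        q = (w / scale + zero).round().clamp(0, q_max)        # asymmetric quantization
        return ((q - zero) * scale).reshape(shape)            # dequantize back to float

    def search_clip_ratio(w, w_bit=2, group_size=128, ratios=(1.0, 0.9, 0.8, 0.7, 0.6)):
        """Pick the clip ratio with the lowest weight reconstruction error
        (the repo's search scores against calibration activations instead)."""
        errs = [(r, (asym_fake_quant(w, w_bit, group_size, r) - w).pow(2).mean().item())
                for r in ratios]
        return min(errs, key=lambda e: e[1])[0]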

2.2. Generating Teacher Data

  • For QAT, create data using the Teacher Model (BF16). The data varies depending on the model (see data/generation).
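The actual generation scripts are in data/generation (generate_vllm.py, generate.sh). As a rough sketch of the idea only, the snippet below samples completions from the BF16 teacher on domain prompts with plain transformers; the prompt list, sampling settings, and output format are illustrative assumptions.

    # Rough sketch: sample teacher (BF16) completions to build the QAT training set.
    # See data/generation/ for the real scripts (vLLM is much faster in practice).
    import json, torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "<model_path>"                      # full-precision teacher, e.g. LLaMA-2-7B
    tok = AutoTokenizer.from_pretrained(model_path)
    teacher = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="auto")

    prompts = ["Explain weight quantization in simple terms."]  # placeholder domain prompts
    samples = []
    for prompt in prompts:
        inputs = tok(prompt, return_tensors="pt").to(teacher.device)
        out = teacher.generate(**inputs, max_new_tokens=1024, do_sample=True,
                               temperature=0.7, top_p=0.95)
        samples.append({"text": tok.decode(out[0], skip_special_tokens=True)})

    with open("teacher_data.json", "w") as f:
        json.dump(samples, f)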

2.3. KD-based QAT

  • The detailed procedure is available in train/; a conceptual sketch of the distillation step follows below.
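Conceptually, each KD-QAT step fake-quantizes the student's weights in the forward pass and pulls its token distribution toward the frozen BF16 teacher on the teacher-generated data. The sketch below uses a plain KL objective for illustration only; the actual objective (BitDistiller's confidence-aware KL, CAKLD) and the quantization hooks are implemented in train/.

    # Illustrative KD-QAT step: distill a frozen teacher into the (fake-quantized) student.
    # The real training loop, CAKLD loss, and quantization wrappers are in train/.
    import torch
    import torch.nn.functional as F

    def kd_step(student, teacher, input_ids, optimizer, temperature=1.0):
        with torch.no_grad():
            t_logits = teacher(input_ids).logits        # frozen BF16 teacher
        s_logits = student(input_ids).logits            # weights fake-quantized in forward
        t_prob = F.softmax(t_logits / temperature, dim=-1)
        s_logp = F.log_softmax(s_logits / temperature, dim=-1)
        # Token-level KL(teacher || student); CAKLD additionally blends forward and
        # reverse KL with a confidence coefficient.
        loss = F.kl_div(s_logp, t_prob, reduction="batchmean") * temperature ** 2
        optimizer.zero_grad()
        loss.backward()                                 # STE routes grads to the fp weights
        optimizer.step()
        return loss.item()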

Example Scripts

LLaMA-2
  1. Get the Clipping result

    cd BitDistiller/quantization
    
    CUDA_VISIBLE_DEVICES=0 python autoclip.py --model_path <model_path> --calib_dataset pile --quant_type int --w_bit 2 --q_group_size 128 --run_clip --dump_clip ./clip_cache/hf-llama2-7b/int2-g128.pt
  2. Get the Teacher Generation Data (using vLLM is much faster)

    # Option 1: vLLM (faster)
    python generate_vllm.py --base_model <model_path> --dataset_name wikitext --out_path ./datasets/hf-llama-2-7b/ --max_sample 3000

    python generate_vllm.py --base_model <model_path> --dataset_name alpaca --out_path ./datasets/hf-llama-2-7b/ --max_sample 5000

    # set the dataset paths inside mix_data.py, then mix the two datasets
    python mix_data.py

    # Option 2: torchrun
    cd BitDistiller/data/generation

    bash generate.sh <model_path> wikitext ../datasets/hf-llama-2-7b/ 16 3000

    bash generate.sh <model_path> alpaca ../datasets/hf-llama-2-7b/ 16 5000

    # set the dataset paths inside mix_data.py, then mix the two datasets
    python mix_data.py
  3. Run KD-based QAT

    # Specify the pre-trained model path
    # Specify the num_gpus and batch_size according to your GPU devices
    # Specify the clipping cache path via --clip
    
    cd train
    
    bash train.sh ../data/datasets/hf-llama-2-7b/mix_wiki_alpaca_8000.json ./ckpts/hf-llama-2-7b/int2-g128/ ./logs/hf-llama-2-7b/int2-g128/ 4
WizardCoder
  1. Get the Clipping result

    cd BitDistiller/quantization
    
    CUDA_VISIBLE_DEVICES=0 python autoclip.py --model_path <model_path> --calib_dataset code --quant_type int --w_bit 2 --q_group_size 128 --run_clip --dump_clip ./clip_cache/WizardCoder-7B/int2-g128.pt
  2. Get the Teacher Generation Data

    # vllm
    python generate_vllm.py --base_model <model_path> --dataset_name code --out_path ./datasets/WizardCoder-7b/ --max_sample 3000
    cd BitDistiller/data/generation
    
    bash generate.sh /root/WizardCoder-Python-7B/ code ../datasets/WizardCoder-7b/ 16 3000
  3. Run KD-based QAT

    # Specify the pre-trained model path
    # Specify the num_gpus and batch_size according to your GPU devices
    # Specify the clipping cache path via --clip
    
    cd train
    
    bash train.sh ../data/datasets/WizardCoder-7b/code_T0.7_N1024_S42_3000.json ./ckpts/WizardCoder-7b/int2-g128/ ./logs/WizardCoder-7b/int2-g128/ 2
MetaMath
  1. Get the Clipping result

    cd BitDistiller/quantization
    
    CUDA_VISIBLE_DEVICES=0 python autoclip.py --model_path <model_path> --calib_dataset gsm8k --quant_type int --w_bit 2 --q_group_size 128 --run_clip --dump_clip ./clip_cache/MetaMath-7B/int2-g128.pt
  2. Get the Teacher Generation Data

    # vllm
    python generate_vllm.py --base_model <model_path> --dataset_name math --out_path ./datasets/MetaMath-7B/ --max_sample 3000
    cd BitDistiller/data/generation
    
    bash generate.sh /root/MetaMath-7B-V1.0/ math ../datasets/MetaMath-7B/ 16 3000
  3. Run KD-based QAT

    # Specify the pre-trained model path
    # Specify the num_gpus and batch_size according to your GPU devices
    # Specify the clipping cache path via --clip
    
    cd train
    
    bash train.sh ../data/datasets/MetaMath-7B/math_T0.7_N1024_S42_3000.json ./ckpts/MetaMath-7b/int2-g128/ ./logs/MetaMath-7b/int2-g128/ 2

3. Evaluation

Example Scripts

LLaMA-2
  • Test PPL on WikiText-2 (a conceptual sketch of the perplexity computation follows after this list)
    cd test/general
    
    python wiki_ppl.py --model ../../train/ckpts/hf-llama-2-7b/int2-g128/checkpoint-200/ --quant_type int --bits 2 --group_size 128
  • Test MMLU
    CUDA_VISIBLE_DEVICES=0 python llm_eval.py --model ../../train/ckpts/hf-llama-2-7b/int2-g128/checkpoint-200/ --eval_tasks hendrycksTest-* --test_set --bits 2 --group_size 128 --quant_type int --num_fewshot 5
  • Test Common-sense QA Tasks
    CUDA_VISIBLE_DEVICES=0 python llm_eval.py --model ../../train/ckpts/hf-llama-2-7b/int2-g128/checkpoint-200/ --eval_tasks arc_challenge,winogrande,hellaswag,piqa --test_set --bits 2 --group_size 128 --quant_type int --num_fewshot 0 
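For reference, the snippet below sketches how WikiText-2 perplexity is typically computed for a causal LM checkpoint. It is an illustration under assumptions: wiki_ppl.py additionally re-applies the low-bit quantization (--quant_type, --bits, --group_size) before scoring, and its exact windowing may differ.

    # Rough sketch of WikiText-2 perplexity evaluation for a causal LM checkpoint.
    # The repo's wiki_ppl.py quantizes the weights first; this sketch skips that step.
    import torch
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    ckpt = "../../train/ckpts/hf-llama-2-7b/int2-g128/checkpoint-200/"
    tok = AutoTokenizer.from_pretrained(ckpt)
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16,
                                                 device_map="auto").eval()

    text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
    ids = tok(text, return_tensors="pt").input_ids
    seqlen, nlls = 2048, []
    for i in range(0, ids.size(1) - seqlen, seqlen):
        batch = ids[:, i:i + seqlen].to(model.device)
        with torch.no_grad():
            nlls.append(model(batch, labels=batch).loss * seqlen)  # approx. summed NLL
    ppl = torch.exp(torch.stack(nlls).sum() / (len(nlls) * seqlen))
    print(f"WikiText-2 PPL: {ppl.item():.2f}")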
WizardCoder
  • Set up the environment following the HumanEval instructions.

  • Example script:

    cd test/humaneval
    bash gen_preds.sh [checkpoint_path] ./preds/7b/int2-g128/
MetaMath
  • Example script:

    cd test/gsm8k
    bash test.sh ../../train/ckpts/MetaMath-7b/int2-g128/ ./preds/7b/int2-g128/

4. Inference

Please see inference/

Reference

If you find BitDistiller useful or relevant to your research, please kindly cite our paper:

@misc{du2024bitdistiller,
      title={BitDistiller: Unleashing the Potential of Sub-4-Bit LLMs via Self-Distillation}, 
      author={Dayou Du and Yijia Zhang and Shijie Cao and Jiaqi Guo and Ting Cao and Xiaowen Chu and Ningyi Xu},
      year={2024},
      eprint={2402.10631},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}


BitDistiller's Issues

CUDA out of memory

Hello, thank you for your contributions!
CUDA out of memory occurs when running asymmetric clipping (LLaMA-2-7B) on a single A800-80G, although the paper states that a single A100 is enough. Why does this happen?

About "inference"

I get "No module named tinychat". Do I need to install more packages?

Question about the dataset size

Hello, thanks for your great work.
I have a question about the dataset used in Table 6 of BitDistiller:
LLM-QAT is a data-free method, and from Appendix A.3 I understand that BitDistiller only uses "a small portion of the involved datasets", so what do the 100k and 2k data sizes in Table 6 mean?

Question about the 2-bit kernel

Hello,

Thanks for your outstanding work. I am curious about the 2-bit inference speedup.

I would like to know which quantization kernel is used during the speed testing.

Is it taken directly from the GPTQ Triton kernel, or have you made some improvements?

4-bit or higher?

Have you evaluated the effects of quantization at 4 bits or higher?
