

LongAlign: A Recipe for Long Context Alignment of LLMs

🤗 HF Repo • 📃 Paper

Read this in Chinese

LongAlign is the first full recipe for LLM alignment on long context. We propose the LongAlign-10k dataset, containing 10,000 long instruction examples of 8k-64k tokens in length. We investigate training strategies, namely packing (with loss weighting) and sorted batching, both of which are implemented in our code. For real-world long context evaluation, we introduce LongBench-Chat, which evaluates instruction-following capability on queries of 10k-100k tokens in length.


⚙️ Data Preparation

You can download and save the LongAlign-10k data using the Hugging Face datasets library (🤗 HF Repo):

from datasets import load_dataset

dataset = load_dataset('THUDM/LongAlign-10k')
for split, split_dataset in dataset.items():
    split_dataset.to_json("data/raw/long.jsonl")

The ShareGPT data can be downloaded from here. We refer to the open-instruct repository for preprocessing the ShareGPT data. Please save the data file at data/raw/sharegpt.jsonl. You can use other data as a source of general instruction data, but please format your data as follows:

{
    "messages": [{"role": "user", "content": "..."},
                 {"role": "assistant", "content": "..."}, ...]
}
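For example, here is a minimal sketch (not part of the repository) of writing your own general instruction data to data/raw/sharegpt.jsonl in this format; the sample conversation is purely illustrative:

import json
import os

# Purely illustrative records; replace with your own instruction data.
examples = [
    {
        "messages": [
            {"role": "user", "content": "Summarize the following document: ..."},
            {"role": "assistant", "content": "The document discusses ..."},
        ]
    }
]

os.makedirs("data/raw", exist_ok=True)
with open("data/raw/sharegpt.jsonl", "w", encoding="utf-8") as f:
    for example in examples:
        f.write(json.dumps(example, ensure_ascii=False) + "\n")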

🖥️ LongAlign Training

Environmental Setup

Install the requirements with pip: pip install -r requirements.txt. For Llama-based models, we recommend using FlashAttention 2 to speed up training and save GPU memory. The relevant dependencies can be installed following the FlashAttention codebase.

Data preprocessing

First, tokenize the raw text data using the tokenizer of the model. For example, when training ChatGLM:

python pre_tokenize.py --model chatglm --datanum 10k

The --datanum parameter here refers to the amount of long data you want in your mixed training dataset (our paper investigates 0k, 5k, and 10k). The tokenized data will be saved under ./data/chatglm/10k.

For the packing and sorted batching strategies, we then organize the tokenized data for training:

python sort_and_group.py --group_size 8 --train_file ./data/chatglm/10k

You should set the --group_size parameter to the number of GPUs used during training. We recommend using at least 8 80G GPUs for model training; otherwise, training on 64k-length sequences may run out of memory.
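For intuition, here is a minimal sketch of the sorted batching idea (not the repository's sort_and_group.py): examples are sorted by tokenized length and carved into groups of group_size, so the GPUs in a group all see sequences of similar length. The field name input_ids and the helper below are illustrative assumptions.

import random

def sorted_batching(examples, group_size):
    """Toy sketch: sort by tokenized length, cut into groups of `group_size`
    examples (one example per GPU here), then shuffle the groups so the
    training order is not strictly short-to-long while lengths stay similar
    within each group."""
    examples = sorted(examples, key=lambda ex: len(ex["input_ids"]))
    groups = [examples[i:i + group_size] for i in range(0, len(examples), group_size)]
    random.shuffle(groups)
    return [ex for group in groups for ex in group]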

Model training

We provide training scripts under scripts/ for the ChatGLM3 and Llama-2 model series. Make sure to adjust --model_name_or_path, --train_file, and --output_dir to match your model path, data path, and output path. You should consider using a base model with at least 64k context window length. We release three base models with extended context windows of 64k: LongAlign-6B-64k-base, LongAlign-7B-64k-base, and LongAlign-13B-64k-base.

For packing training, please modify the attention calculation to support the 1D attention mask that marks the start and end positions of each sequence in the pack, and modify the model forward function to support loss weighting during packing training. An example of such modifications for the ChatGLM3 model is provided in modeling_chatglm.py, in CoreAttention.forward and ChatGLMForConditionalGeneration.forward. You can directly use this file as the modeling file for ChatGLM packing training. We also provide the training code for Llama. To reproduce our results, please use modeling_llama.py as the modeling file. As the results in our paper suggest, we recommend packing with loss weighting for ChatGLM training and sorted batching for Llama.
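As a rough illustration of these two ingredients (an assumption-laden sketch, not the code in modeling_chatglm.py or modeling_llama.py): the 1D attention mask can be represented as cumulative sequence boundaries, so that variable-length attention kernels only attend within each packed sequence, and loss weighting gives the target tokens of sequence i a weight of 1/(N_i * M), so that summing the weighted token losses over all packs reproduces "average within each sequence, then across the batch of M sequences".

import torch

def cu_seqlens_from_pack(sequence_lengths):
    # 1D "attention mask" as cumulative boundaries, e.g. [3, 5] -> [0, 3, 8];
    # varlen attention kernels use this to keep attention within each sequence.
    return torch.cumsum(torch.tensor([0] + list(sequence_lengths)), dim=0)

def packed_loss_weights(target_token_counts, batch_num_sequences):
    # Weight each target token of sequence i by 1 / (N_i * M), where N_i is the
    # number of target tokens in sequence i and M is the number of sequences in
    # the (global) batch, so the weighted sum equals the per-sample-averaged loss.
    weights = []
    for n_i in target_token_counts:
        weights += [1.0 / (n_i * batch_num_sequences)] * n_i
    return torch.tensor(weights)

# Example: a pack holding two sequences with 3 and 1 target tokens, in a batch
# of M = 4 sequences overall (the other two sequences live in other packs).
print(cu_seqlens_from_pack([3, 1]))    # tensor([0, 3, 4])
print(packed_loss_weights([3, 1], 4))  # weights 1/12, 1/12, 1/12, 1/4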

Model deploying

We have released four chat models trained using LongAlign: LongAlign-6B-64k (based on ChatGLM3-6B), LongAlign-7B-64k (based on Llama-2-7B), LongAlign-13B-64k (based on Llama-2-13B), and ChatGLM3-6B-128k. Try asking the model to summarize our paper, or ask it anything else about the paper:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
tokenizer = AutoTokenizer.from_pretrained("THUDM/LongAlign-6B-64k", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("THUDM/LongAlign-6B-64k", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
model = model.eval()
query = open("assets/paper.txt").read() + "\n\nPlease summarize the paper."
response, history = model.chat(tokenizer, query, history=[], max_new_tokens=512, temperature=1)
print(response)

For Llama-based models, we also provide llama_flash_attn_monkey_patch.py, which uses FlashAttention-2 to save memory during inference on long sequences.
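A typical way to apply such a patch is to call it before the model is instantiated. The function name below is an assumption made for illustration (it follows the naming used by similar monkey patches); check llama_flash_attn_monkey_patch.py for the actual entry point.

import torch
from transformers import AutoModelForCausalLM

# Hypothetical entry point: verify the real function name in
# llama_flash_attn_monkey_patch.py before using this.
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()  # patch attention before loading the model
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/LongAlign-7B-64k", torch_dtype=torch.bfloat16, device_map="auto"
)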

All available models

Here is the full list of models we released:

Model                  | HF Repo    | Description
LongAlign-6B-64k-base  | 🤗 HF Repo | ChatGLM3-6B with an extended 64k context window
LongAlign-6B-64k       | 🤗 HF Repo | Chat model trained with LongAlign on LongAlign-6B-64k-base
LongAlign-7B-64k-base  | 🤗 HF Repo | Llama-2-7B with an extended 64k context window
LongAlign-7B-64k       | 🤗 HF Repo | Chat model trained with LongAlign on LongAlign-7B-64k-base
LongAlign-13B-64k-base | 🤗 HF Repo | Llama-2-13B with an extended 64k context window
LongAlign-13B-64k      | 🤗 HF Repo | Chat model trained with LongAlign on LongAlign-13B-64k-base
ChatGLM3-6B-128k       | 🤗 HF Repo | ChatGLM3-6B with a 128k context window

📊 Evaluation

LongBench-Chat evaluation

LongBench-Chat is the first benchmark for assessing long context alignment, featuring real user queries of 10k-100k tokens in length. The dataset and evaluation code are available under LongBench_Chat/. Remember to configure your OpenAI API key in eval.py, since we adopt GPT-4 as the evaluator. Run:

python eval.py --model {model_path} --max_length {max_length}

model_path can be either a local model path or a Hugging Face model path. [Figure: leaderboard on LongBench-Chat]
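For intuition, the judging loop is conceptually similar to the minimal sketch below. The prompt wording, the rating scale, and the judge model name are illustrative assumptions; the actual prompt and scoring procedure are defined in eval.py.

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def judge(question, reference_answer, model_answer):
    # Illustrative judging prompt; the real prompt and scale live in eval.py.
    prompt = (
        "You are asked to rate an assistant's answer to a long-context question.\n"
        f"Question: {question}\n"
        f"Reference answer: {reference_answer}\n"
        f"Assistant's answer: {model_answer}\n"
        "Rate the assistant's answer on a scale of 1 to 10 and reply with the number only."
    )
    response = client.chat.completions.create(
        model="gpt-4",  # assumed judge model name
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    return response.choices[0].message.content.strip()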

You are also welcome to submit your model's test predictions or results to us. We are planning to release a more formal leaderboard.

Needle-test evaluation

We also provide the code for evaluating HuggingFace models on the "Needle In A Haystack" test under Needle_test/. See its README.md for more information.

To reproduce our results on other benchmarks, we refer to the code in LongBench, FastChat, and lm-evaluation-harness for evaluating on LongBench, MT-Bench, and Open LLM Leaderboard tasks.

📝 Citation

If you find our work useful, please consider citing LongAlign:

@article{bai2024longalign,
  title={LongAlign: A Recipe for Long Context Alignment of Large Language Models},
  author={Bai, Yushi and Lv, Xin and Zhang, Jiajie and He, Yuze and Qi, Ji and Hou, Lei and Tang, Jie and Dong, Yuxiao and Li, Juanzi},
  journal={arXiv preprint arXiv:2401.18058},
  year={2024}
}

longalign's People

Contributors

bys0318, davidlvxin, neo-zhangjiajie


longalign's Issues

Questions about ChatGLM3-6b-128k

Hello, I am very interested in your work, but I have a few questions below.

  1. How do you further train a chat model like ChatGLM3-6B? Since these models use chat templates, will further training hurt their instruction-following ability or cause instability?

  2. How was ChatGLM3-6B-128k obtained? From the paper I see that ChatGLM3-6B-64k is trained from ChatGLM3-32k with 10B tokens of further training data plus SFT with LongAlign, but I may have missed how ChatGLM3-6B-128k was obtained.

[BUG] Loading chatglm3-6b-128k with the Langchain-Chatchat framework: the model keeps asking and answering its own questions and does not stop

Problem Description
When using chatglm3-6b-128k, the model keeps asking and answering its own questions and cannot stop.


Steps to Reproduce
1. Use chatglm3-6b-128k.
2. No matter what is asked, the model starts asking and answering its own questions.

Expected Result
The model stops after it finishes answering.

Actual Result
After answering the current question, the model keeps asking and answering its own questions and cannot stop.

Environment Information

  • langchain-ChatGLM version/commit: v0.2.10
  • Deployed with Docker (yes/no): no
  • Model used (ChatGLM2-6B / Qwen-7B, etc.): ChatGLM3-6B-128k
  • Embedding model used (moka-ai/m3e-base, etc.): m3e-base
  • Vector store type (faiss / milvus / pg_vector, etc.): faiss
  • Operating system and version: Linux
  • Python version: 3.11
  • Other relevant environment information:

Additional Information
I wanted to work around this by adding a repetition_penalty, but I could not find where to set that parameter; a pointer would be appreciated.

On the difference between the packing loss and the plain batch loss?

The paper points out that the packing loss is not identical to the plain batch loss, based on a formula of the form loss = (1/M) * sum_i (1/N_i) * sum_j l_ij (with M sequences in the batch and N_i target tokens in sequence i): the loss is computed at sample granularity, first averaged within each sample and then averaged across the batch, in two steps.

In my understanding, SFT training usually computes the final loss at token granularity, i.e. "sum of target-token losses / total number of target tokens", not at sample granularity.

Looking at your implementation, the plain-batch path in modeling_llama.py flattens batch*seq into a single sequence, so the loss is computed directly at token granularity rather than sample granularity (i.e. it does not first average within each sequence and then across the batch).

Two questions for discussion:

  1. In SFT, should the final averaging of the loss be done at token granularity or at sample granularity?
  2. If token granularity is used, I believe packing and non-packing are equivalent.
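For readers following this discussion, here is a small self-contained sketch (not from the repository) contrasting the two averaging schemes; with unequal sequence lengths they generally differ, which is why naive packing, which averages over all tokens in the batch, changes the loss relative to per-sample averaging:

import torch

# Illustrative per-token losses for a batch of two sequences of unequal length.
per_token_losses = [torch.tensor([1.0, 1.0, 1.0]),  # sequence 1: 3 target tokens
                    torch.tensor([2.0])]            # sequence 2: 1 target token

# Token granularity (what flattening batch*seq into one sequence computes):
token_level = torch.cat(per_token_losses).mean()    # (1+1+1+2)/4 = 1.25

# Sample granularity (average within each sequence, then across the batch):
sample_level = torch.stack([l.mean() for l in per_token_losses]).mean()  # (1+2)/2 = 1.5

print(token_level.item(), sample_level.item())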
