
ChatGLM-finetune-LoRA

This repository contains code for fine-tuning ChatGLM-6b using low-rank adaptation (LoRA).

We also provide fine-tuned weights.

The minimum required GPU memory is 24 GB; a single RTX 3090 is enough for training.

  • 2023/3/28: Optimized the code structure to be simpler and clearer; added training instructions.
  • 2023/3/24: Added support for multi-GPU training, DeepSpeed, and batch collation; use accelerate to launch train.py.

Easy to use

import loralib as lora
import lora_utils.insert_lora
import dataset.GLM as GLM_Data
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel

device = 'cuda'
checkpoint = "THUDM/chatglm-6b"


# load model
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

# get LoRA model
lora_config = {
    'r': 32,
    'lora_alpha': 32,
    'lora_dropout': 0.1,
    'enable_lora': [True, True, True],
}
model = lora_utils.insert_lora.get_lora_model(model, lora_config)
### trainable_params:22020096 (0.35%), non_trainable_params:6255206400

# get Dataloader
pairs = [{'prompt':'Hello!', 'completion':'Hi! This is ChatGLM.'}]
pairs_encoded = GLM_Data.encode_pairs(pairs, tokenizer)
train_dataset = GLM_Data.GLMDataset(pairs_encoded)
train_dataloader = DataLoader(dataset=train_dataset, collate_fn=GLM_Data.collate_fn, shuffle=True, batch_size=1)

# training: cast to fp16, then run a single forward / backward pass
model.half().to(device)
batch = {k: v.to(device) for k, v in next(iter(train_dataloader)).items()}
outputs = model(**batch)
outputs.loss.backward()
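
The snippet above runs only a single forward / backward pass. Below is a minimal sketch of a full training step, assuming a plain PyTorch optimizer over the LoRA parameters; the learning rate is an illustrative value, and the repo's own training script (train_new.py, launched via accelerate) may differ:

import torch

# optimize only the trainable (LoRA) parameters; lr is illustrative
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)

for batch in train_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Note that training a model cast to .half() with a plain fp16 optimizer can be numerically unstable; the repo relies on accelerate / DeepSpeed for proper mixed-precision handling.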

Training

Use the accelerate CLI tool to launch multi-process / distributed training:

accelerate launch --config_file config/default_config.yaml train_new.py
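
Conceptually, the launched script wraps the training loop with accelerate; a minimal sketch (illustrative, not the exact contents of train_new.py):

from accelerate import Accelerator

accelerator = Accelerator()
# let accelerate place model, optimizer, and data on the right devices / processes
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

for batch in train_dataloader:
    loss = model(**batch).loss
    accelerator.backward(loss)  # replaces loss.backward(); handles grad scaling / DeepSpeed
    optimizer.step()
    optimizer.zero_grad()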

Like OpenAI's fine-tuning API, the data should be in the following structure:

[
    {'prompt': <enter the prompt here (can be an instruction)>, 'completion': <the expected completion>},
    {'prompt': <enter the prompt here (can be an instruction)>, 'completion': <the expected completion>},
    ...,
    {'prompt': <enter the prompt here (can be an instruction)>, 'completion': <the expected completion>},
]

It is a list of prompt-completion pairs.
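
For example, such a file can be produced with a few lines of Python (the file name train_data.json is a placeholder; check the repo for where the training script expects its data):

import json

pairs = [
    {'prompt': 'Hello!', 'completion': 'Hi! This is ChatGLM.'},
]
with open('train_data.json', 'w', encoding='utf-8') as f:
    json.dump(pairs, f, ensure_ascii=False)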

Stanford Alpaca's Dataset

Here we use the Stanford Alpaca dataset as an example for fine-tuning. We also provide fine-tuned weights.

Example line:

{'prompt': 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nClassify the movie genres from the given context.\n\n### Input:\nThis movie tells the story of two brothers who were both born with magical powers.\n\n### Response:', 'completion': 'Fantasy'}
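
Records in the original Alpaca dataset carry instruction / input / output fields; a hypothetical converter into the prompt-completion format (the template below matches the example line; Alpaca also uses a shorter template when input is empty, which is omitted here):

# hypothetical helper; the repo may ship its own preprocessing
ALPACA_TEMPLATE = (
    'Below is an instruction that describes a task, paired with an input that '
    'provides further context. Write a response that appropriately completes '
    'the request.\n\n### Instruction:\n{instruction}\n\n'
    '### Input:\n{input}\n\n### Response:'
)

def alpaca_to_pair(record):
    return {
        'prompt': ALPACA_TEMPLATE.format(**record),
        'completion': record['output'],
    }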

Training on the Stanford Alpaca dataset should take less than 30 minutes per epoch on 4×V100 GPUs.

You may observe a typical training loss curve (see the example_training_loss figure in the repository); note that the curve varies across datasets.

LoRA

lora_config = {
    'r': 32,
    'lora_alpha': 32,
    'lora_dropout': 0.1,
    'enable_lora': [True, True, True],
}

With the LoRA config above, there are 22,020,096 trainable parameters (0.35%) and 6,255,206,400 frozen parameters.
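
These counts can be reproduced directly from the model as a quick sanity check:

# count trainable (LoRA) vs. frozen parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f'trainable_params:{trainable} ({trainable / (trainable + frozen):.2%}), '
      f'non_trainable_params:{frozen}')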

Save & Load

# save only the LoRA weights (tens of MB instead of the full model)
torch.save(lora.lora_state_dict(model), 'path to file you saved')
# load them back into a LoRA-injected model; strict=False leaves the frozen base weights as-is
model.load_state_dict(torch.load('path to file you saved'), strict=False)
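
To run inference with the saved LoRA weights, a sketch along these lines should work (finetuned.pt is a placeholder path; lora_config is the dict shown above):

import torch
from transformers import AutoTokenizer, AutoModel
import lora_utils.insert_lora

checkpoint = 'THUDM/chatglm-6b'
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)
model = lora_utils.insert_lora.get_lora_model(model, lora_config)  # same config as training
model.load_state_dict(torch.load('finetuned.pt'), strict=False)    # fills in only the LoRA weights
model.half().cuda().eval()

# ChatGLM-6b exposes a chat() helper through trust_remote_code
response, history = model.chat(tokenizer, 'Hello!', history=[])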
